Skip to main content

parley/services/
ai_session.rs

1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3use std::process::Stdio;
4use std::sync::mpsc;
5use std::time::Duration;
6
7use anyhow::{Context, Result, anyhow};
8use include_dir::{Dir, include_dir};
9use serde::Serialize;
10use serde_json::Value;
11use tokio::fs;
12use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWriteExt, BufReader};
13use tokio::process::Command;
14use tokio::task::JoinHandle;
15use tokio::time::timeout;
16use tracing::{debug, error, info, warn};
17
18use crate::domain::ai::{AiProvider, AiSessionMode};
19use crate::domain::config::{AppConfig, PromptTransport};
20use crate::domain::diff::{DiffDocument, DiffFile, DiffHunk};
21use crate::domain::reference::parse_file_references;
22use crate::domain::review::{Author, CommentStatus, LineComment, ReviewState};
23use crate::git::diff::load_git_diff_head;
24use crate::services::review_service::{AddReplyInput, ReviewService};
25
26static AI_SESSION_PROMPTS_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/prompts/ai_session");
27
28#[derive(Debug, Clone)]
29pub struct RunAiSessionInput {
30    pub review_name: String,
31    pub provider: AiProvider,
32    pub comment_ids: Vec<u64>,
33    pub mode: AiSessionMode,
34}
35
36#[derive(Debug, Clone, Serialize)]
37#[serde(rename_all = "snake_case")]
38pub struct AiSessionResult {
39    pub review_name: String,
40    pub provider: String,
41    pub mode: String,
42    pub client: String,
43    pub model: Option<String>,
44    pub session_id: String,
45    pub processed: usize,
46    pub skipped: usize,
47    pub failed: usize,
48    pub items: Vec<AiSessionItemResult>,
49}
50
51#[derive(Debug, Clone, Serialize)]
52#[serde(rename_all = "snake_case")]
53pub struct AiSessionItemResult {
54    pub comment_id: u64,
55    pub status: String,
56    pub message: String,
57}
58
59#[derive(Debug, Clone, Serialize)]
60#[serde(rename_all = "snake_case")]
61pub struct AiProgressEvent {
62    pub timestamp_ms: u64,
63    pub provider: String,
64    pub stream: String,
65    pub message: String,
66}
67
/// Reply text from one successful provider run and the model it is attributed to.
#[derive(Debug, Clone)]
struct ProviderInvocation {
    /// Reply extracted from the provider output.
    reply: String,
    /// Model detected in the provider output, falling back to the configured model.
    model: Option<String>,
}
73
74pub fn default_ai_session_mode(comment_ids: &[u64]) -> AiSessionMode {
75    if comment_ids.is_empty() {
76        AiSessionMode::Refactor
77    } else {
78        AiSessionMode::Reply
79    }
80}
81
82pub async fn run_ai_session(
83    service: &ReviewService,
84    input: RunAiSessionInput,
85) -> Result<AiSessionResult> {
86    run_ai_session_inner(service, input, None).await
87}
88
89pub async fn run_ai_session_with_progress(
90    service: &ReviewService,
91    input: RunAiSessionInput,
92    progress_sender: mpsc::Sender<AiProgressEvent>,
93) -> Result<AiSessionResult> {
94    run_ai_session_inner(service, input, Some(progress_sender)).await
95}
96
97async fn run_ai_session_inner(
98    service: &ReviewService,
99    input: RunAiSessionInput,
100    progress_sender: Option<mpsc::Sender<AiProgressEvent>>,
101) -> Result<AiSessionResult> {
102    info!(
103        review = %input.review_name,
104        provider = %input.provider.as_str(),
105        requested_comments = input.comment_ids.len(),
106        "starting ai session"
107    );
108    let config = service.load_config().await?;
109    let mut review = service.load_review(&input.review_name).await?;
110    let diff_document = match load_git_diff_head().await {
111        Ok(document) => Some(document),
112        Err(error) => {
113            warn!(error = %error, "ai session prompt context: unable to load git diff");
114            None
115        }
116    };
117    let now_ms = now_ms()?;
118    let provider_cfg = config.ai.provider_config(input.provider);
119    let mut result = AiSessionResult {
120        review_name: input.review_name.clone(),
121        provider: input.provider.as_str().to_string(),
122        mode: input.mode.as_str().to_string(),
123        client: provider_cfg.client.clone(),
124        model: provider_cfg.model.clone(),
125        session_id: format!("{}-{}-{now_ms}", input.review_name, input.provider.as_str()),
126        processed: 0,
127        skipped: 0,
128        failed: 0,
129        items: Vec::new(),
130    };
131
132    if matches!(review.state, ReviewState::Done) {
133        warn!(
134            review = %input.review_name,
135            provider = %input.provider.as_str(),
136            "ai session skipped because review is done"
137        );
138        result.items.push(AiSessionItemResult {
139            comment_id: 0,
140            status: "skipped".to_string(),
141            message: "review is done; ai session ignored".to_string(),
142        });
143        result.skipped = 1;
144        return Ok(result);
145    }
146
147    let target_ids: Vec<u64> = if input.comment_ids.is_empty() {
148        review
149            .comments
150            .iter()
151            .filter(|comment| comment_is_targetable(comment.status.clone(), input.mode))
152            .map(|comment| comment.id)
153            .collect()
154    } else {
155        input.comment_ids.clone()
156    };
157    let total_targets = target_ids.len();
158    if total_targets == 0 {
159        result.items.push(AiSessionItemResult {
160            comment_id: 0,
161            status: "skipped".to_string(),
162            message: match input.mode {
163                AiSessionMode::Reply => "no replyable threads to process".to_string(),
164                AiSessionMode::Refactor => "no open threads to process".to_string(),
165            },
166        });
167        result.skipped = 1;
168        emit_progress(
169            progress_sender.as_ref(),
170            input.provider,
171            "system",
172            "no open threads to process",
173        );
174        return Ok(result);
175    }
176
177    let explicit_selection = !input.comment_ids.is_empty();
178    for (step_index, comment_id) in target_ids.into_iter().enumerate() {
179        emit_progress(
180            progress_sender.as_ref(),
181            input.provider,
182            "system",
183            format!(
184                "thread #{comment_id}: start ({}/{})",
185                step_index + 1,
186                total_targets
187            ),
188        );
189        debug!(
190            review = %input.review_name,
191            provider = %input.provider.as_str(),
192            comment_id,
193            "processing ai thread"
194        );
195        let maybe_comment = review
196            .comments
197            .iter()
198            .find(|comment| comment.id == comment_id);
199        let Some(comment) = maybe_comment else {
200            warn!(
201                review = %input.review_name,
202                provider = %input.provider.as_str(),
203                comment_id,
204                "ai session target comment not found"
205            );
206            result.failed += 1;
207            result.items.push(AiSessionItemResult {
208                comment_id,
209                status: "failed".to_string(),
210                message: "comment not found in review".to_string(),
211            });
212            emit_progress(
213                progress_sender.as_ref(),
214                input.provider,
215                "system",
216                format!("thread #{comment_id}: failed (comment not found)"),
217            );
218            continue;
219        };
220
221        let allow_selected_reply = explicit_selection && matches!(input.mode, AiSessionMode::Reply);
222        if !comment_is_targetable(comment.status.clone(), input.mode) && !allow_selected_reply {
223            debug!(
224                review = %input.review_name,
225                provider = %input.provider.as_str(),
226                comment_id,
227                status = ?comment.status,
228                "skipping non-targetable comment for selected mode"
229            );
230            result.skipped += 1;
231            result.items.push(AiSessionItemResult {
232                comment_id,
233                status: "skipped".to_string(),
234                message: format!(
235                    "comment status {:?} is not targetable for {} mode",
236                    comment.status,
237                    input.mode.as_str()
238                ),
239            });
240            emit_progress(
241                progress_sender.as_ref(),
242                input.provider,
243                "system",
244                format!(
245                    "thread #{comment_id}: skipped (status={:?})",
246                    comment.status
247                ),
248            );
249            continue;
250        }
251
252        let prompt = build_thread_prompt(
253            &input.review_name,
254            comment_id,
255            &review,
256            diff_document.as_ref(),
257            input.mode,
258        );
259        let provider_reply = match invoke_provider(
260            &config,
261            input.provider,
262            input.mode,
263            &prompt,
264            progress_sender.clone(),
265        )
266        .await
267        {
268            Ok(reply) => reply,
269            Err(error) => {
270                error!(
271                    review = %input.review_name,
272                    provider = %input.provider.as_str(),
273                    comment_id,
274                    error = %error,
275                    "provider invocation failed"
276                );
277                result.failed += 1;
278                result.items.push(AiSessionItemResult {
279                    comment_id,
280                    status: "failed".to_string(),
281                    message: format!("provider failed: {error}"),
282                });
283                emit_progress(
284                    progress_sender.as_ref(),
285                    input.provider,
286                    "system",
287                    format!("thread #{comment_id}: failed ({error})"),
288                );
289                continue;
290            }
291        };
292        let reply_body =
293            format_ai_reply_body(provider_reply.model.as_deref(), &provider_reply.reply);
294
295        let updated = match service
296            .add_reply(
297                &input.review_name,
298                AddReplyInput {
299                    comment_id,
300                    author: Author::Ai,
301                    body: reply_body,
302                },
303            )
304            .await
305        {
306            Ok(value) => value,
307            Err(error) => {
308                error!(
309                    review = %input.review_name,
310                    provider = %input.provider.as_str(),
311                    comment_id,
312                    error = %error,
313                    "failed to persist ai reply"
314                );
315                result.failed += 1;
316                result.items.push(AiSessionItemResult {
317                    comment_id,
318                    status: "failed".to_string(),
319                    message: format!("failed to persist ai reply: {error}"),
320                });
321                emit_progress(
322                    progress_sender.as_ref(),
323                    input.provider,
324                    "system",
325                    format!("thread #{comment_id}: failed (persist reply: {error})"),
326                );
327                continue;
328            }
329        };
330
331        review = updated;
332        result.processed += 1;
333        info!(
334            review = %input.review_name,
335            provider = %input.provider.as_str(),
336            comment_id,
337            "ai reply persisted"
338        );
339        result.items.push(AiSessionItemResult {
340            comment_id,
341            status: "processed".to_string(),
342            message: match input.mode {
343                AiSessionMode::Reply => "ai reply added".to_string(),
344                AiSessionMode::Refactor => {
345                    "ai reply added; thread status moved to pending_human".to_string()
346                }
347            },
348        });
349        emit_progress(
350            progress_sender.as_ref(),
351            input.provider,
352            "system",
353            format!(
354                "thread #{comment_id}: done ({}/{})",
355                step_index + 1,
356                total_targets
357            ),
358        );
359    }
360
361    info!(
362        review = %input.review_name,
363        provider = %input.provider.as_str(),
364        processed = result.processed,
365        skipped = result.skipped,
366        failed = result.failed,
367        "ai session completed"
368    );
369    Ok(result)
370}
371
372fn build_thread_prompt(
373    review_name: &str,
374    comment_id: u64,
375    review: &crate::domain::review::ReviewSession,
376    diff_document: Option<&DiffDocument>,
377    mode: AiSessionMode,
378) -> String {
379    let Some(comment) = review
380        .comments
381        .iter()
382        .find(|comment| comment.id == comment_id)
383    else {
384        return missing_comment_prompt(review_name, comment_id);
385    };
386
387    let mut thread = String::new();
388    thread.push_str(&format!("Review: {review_name}\n"));
389    thread.push_str(&format!(
390        "Thread comment id: {}\nFile: {}\nLine: {}:{}\nStatus: {:?}\n",
391        comment.id,
392        comment.file_path,
393        comment
394            .old_line
395            .map(|value| value.to_string())
396            .unwrap_or_else(|| "_".to_string()),
397        comment
398            .new_line
399            .map(|value| value.to_string())
400            .unwrap_or_else(|| "_".to_string()),
401        comment.status
402    ));
403    thread.push_str("\nOriginal comment:\n");
404    thread.push_str(&comment.body);
405    thread.push_str("\n\nReplies so far:\n");
406    if comment.replies.is_empty() {
407        thread.push_str("- (none)\n");
408    } else {
409        for reply in &comment.replies {
410            let author = match reply.author {
411                Author::User => "user",
412                Author::Ai => "ai",
413            };
414            thread.push_str(&format!("- {}: {}\n", author, reply.body));
415        }
416    }
417    append_target_file_and_diff_context(&mut thread, comment, diff_document);
418    append_referenced_files_context(&mut thread, comment);
419
420    match mode {
421        AiSessionMode::Reply => {
422            thread.push_str(prompt_template("reply_task.md"));
423        }
424        AiSessionMode::Refactor => {
425            thread.push_str(prompt_template("refactor_task.md"));
426        }
427    }
428    thread
429}
430
431fn append_target_file_and_diff_context(
432    prompt: &mut String,
433    comment: &LineComment,
434    diff_document: Option<&DiffDocument>,
435) {
436    prompt.push_str("\n\nPrimary target context:\n");
437    let target_line = comment.new_line.or(comment.old_line);
438    match target_line {
439        Some(line) => {
440            prompt.push_str(&format!(
441                "- thread anchor: {}:{}\n",
442                comment.file_path, line
443            ));
444            if let Some(resolved) = resolve_workspace_path(&comment.file_path) {
445                if let Some(snippet) = file_line_snippet(&resolved, line) {
446                    prompt.push_str(&format!(
447                        "  file snippet around {}:{}:\n{}",
448                        comment.file_path, line, snippet
449                    ));
450                } else {
451                    prompt.push_str("  file snippet: unavailable for requested line\n");
452                }
453            } else {
454                prompt.push_str("  file snippet: file not found in workspace\n");
455            }
456        }
457        None => {
458            prompt.push_str(&format!(
459                "- thread anchor: {} (line unavailable)\n",
460                comment.file_path
461            ));
462        }
463    }
464
465    if let Some(document) = diff_document {
466        if let Some(file) = find_diff_file(document, &comment.file_path) {
467            if let Some(hunk) = choose_best_hunk(file, comment.old_line, comment.new_line) {
468                let excerpt = format_hunk_excerpt(hunk, comment.old_line, comment.new_line, 28);
469                prompt.push_str("  nearest diff hunk:\n");
470                prompt.push_str(&excerpt);
471            } else {
472                prompt.push_str("  nearest diff hunk: none for this file\n");
473            }
474        } else {
475            prompt.push_str("  nearest diff hunk: file not present in current git diff\n");
476        }
477    } else {
478        prompt.push_str("  nearest diff hunk: unavailable (failed to load git diff)\n");
479    }
480}
481
482fn append_referenced_files_context(
483    prompt: &mut String,
484    comment: &crate::domain::review::LineComment,
485) {
486    let mut ordered = BTreeSet::new();
487    for reference in parse_file_references(&comment.body) {
488        ordered.insert((reference.path, reference.line));
489    }
490    for reply in &comment.replies {
491        for reference in parse_file_references(&reply.body) {
492            ordered.insert((reference.path, reference.line));
493        }
494    }
495    if ordered.is_empty() {
496        return;
497    }
498
499    prompt.push_str("\n\nReferenced files from thread mentions:\n");
500    for (path, line) in ordered.into_iter().take(8) {
501        let marker = if let Some(value) = line {
502            format!("{path}:{value}")
503        } else {
504            path.clone()
505        };
506        prompt.push_str(&format!("- {marker}\n"));
507        if let (Some(value), Some(resolved)) = (line, resolve_workspace_path(&path))
508            && let Some(snippet) = file_line_snippet(&resolved, value)
509        {
510            prompt.push_str(&format!("  context from {}:\n", resolved.display()));
511            prompt.push_str(&snippet);
512        }
513    }
514}
515
516fn find_diff_file<'a>(document: &'a DiffDocument, path: &str) -> Option<&'a DiffFile> {
517    document.files.iter().find(|file| file.path == path)
518}
519
520fn choose_best_hunk(
521    file: &DiffFile,
522    old_line: Option<u32>,
523    new_line: Option<u32>,
524) -> Option<&DiffHunk> {
525    if file.hunks.is_empty() {
526        return None;
527    }
528
529    for hunk in &file.hunks {
530        if hunk_contains_anchor(hunk, old_line, new_line) {
531            return Some(hunk);
532        }
533    }
534
535    let mut scored = file
536        .hunks
537        .iter()
538        .map(|hunk| (hunk_distance_to_anchor(hunk, old_line, new_line), hunk))
539        .collect::<Vec<_>>();
540    scored.sort_by_key(|(distance, _)| *distance);
541    scored.first().map(|(_, hunk)| *hunk)
542}
543
544fn hunk_contains_anchor(hunk: &DiffHunk, old_line: Option<u32>, new_line: Option<u32>) -> bool {
545    hunk.lines.iter().any(|line| {
546        old_line.is_some() && line.old_line == old_line
547            || new_line.is_some() && line.new_line == new_line
548    })
549}
550
551fn hunk_distance_to_anchor(hunk: &DiffHunk, old_line: Option<u32>, new_line: Option<u32>) -> u32 {
552    let mut best = u32::MAX;
553    if let Some(target_old) = old_line {
554        best = best.min(line_distance(hunk.old_start, target_old));
555    }
556    if let Some(target_new) = new_line {
557        best = best.min(line_distance(hunk.new_start, target_new));
558    }
559    if best == u32::MAX { 0 } else { best }
560}
561
/// Absolute difference between two line numbers; cannot overflow.
fn line_distance(base: u32, target: u32) -> u32 {
    if base >= target {
        base - target
    } else {
        target - base
    }
}
565
566fn format_hunk_excerpt(
567    hunk: &DiffHunk,
568    old_line: Option<u32>,
569    new_line: Option<u32>,
570    max_lines: usize,
571) -> String {
572    if hunk.lines.is_empty() || max_lines == 0 {
573        return String::new();
574    }
575    let center = hunk
576        .lines
577        .iter()
578        .position(|line| {
579            old_line.is_some() && line.old_line == old_line
580                || new_line.is_some() && line.new_line == new_line
581        })
582        .unwrap_or(0);
583    let half_window = max_lines / 2;
584    let mut start = center.saturating_sub(half_window);
585    let end = (start + max_lines).min(hunk.lines.len());
586    if end - start < max_lines && end == hunk.lines.len() {
587        start = end.saturating_sub(max_lines);
588    }
589
590    let mut out = String::new();
591    for line in &hunk.lines[start..end] {
592        out.push_str("    ");
593        out.push_str(&line.raw);
594        out.push('\n');
595    }
596    out
597}
598
/// Resolves `path` (surrounding whitespace ignored) to an existing regular file.
///
/// Absolute paths are used as given; relative paths are joined onto the process's current
/// working directory. Returns `None` for a blank path, when the working directory cannot be
/// determined, or when the resulting path is not an existing file.
fn resolve_workspace_path(path: &str) -> Option<PathBuf> {
    let trimmed = path.trim();
    if trimmed.is_empty() {
        return None;
    }

    let requested = Path::new(trimmed);
    let candidate = if requested.is_absolute() {
        requested.to_path_buf()
    } else {
        std::env::current_dir().ok()?.join(requested)
    };
    Some(candidate).filter(|resolved| resolved.is_file())
}
615
/// Returns a numbered excerpt of the file at `path` around 1-based `line`: up to two lines
/// before, the line itself and up to two lines after, each rendered as `    {n:>5} | {text}`.
///
/// Returns `None` for line `0`, when the file cannot be read as UTF-8 text, or when `line`
/// lies past the end of the file. The whole file is read synchronously.
fn file_line_snippet(path: &Path, line: u32) -> Option<String> {
    let zero_based = line.checked_sub(1)?;
    let contents = std::fs::read_to_string(path).ok()?;
    let all_lines: Vec<&str> = contents.lines().collect();
    let target = usize::try_from(zero_based).ok()?;
    if target >= all_lines.len() {
        return None;
    }

    let first = target.saturating_sub(2);
    let last = all_lines.len().min(target + 3);
    let snippet: String = all_lines[first..last]
        .iter()
        .zip(first + 1..)
        .map(|(text, number)| format!("    {number:>5} | {text}\n"))
        .collect();
    Some(snippet)
}
636
637fn prompt_template(path: &str) -> &'static str {
638    AI_SESSION_PROMPTS_DIR
639        .get_file(path)
640        .unwrap_or_else(|| panic!("missing ai session prompt template: {path}"))
641        .contents_utf8()
642        .unwrap_or_else(|| panic!("invalid utf-8 in ai session prompt template: {path}"))
643}
644
645fn missing_comment_prompt(review_name: &str, comment_id: u64) -> String {
646    prompt_template("comment_not_found.md")
647        .replace("{review_name}", review_name)
648        .replace("{comment_id}", &comment_id.to_string())
649}
650
651async fn invoke_provider(
652    config: &AppConfig,
653    provider: AiProvider,
654    mode: AiSessionMode,
655    prompt: &str,
656    progress_sender: Option<mpsc::Sender<AiProgressEvent>>,
657) -> Result<ProviderInvocation> {
658    let provider_cfg = config.ai.provider_config(provider);
659    if provider_cfg.client.trim().is_empty() {
660        return Err(anyhow!(
661            "provider {} has no configured client in config.toml",
662            provider.as_str()
663        ));
664    }
665
666    let mut command = Command::new(&provider_cfg.client);
667    command.kill_on_drop(true);
668    let args = normalized_provider_args(provider, provider_cfg, mode);
669    command.args(&args);
670    let codex_output_path = codex_output_path(provider)?;
671    if let Some(path) = codex_output_path.as_ref() {
672        if !args.iter().any(|arg| arg == "--json") {
673            command.arg("--json");
674        }
675        command.arg("--output-last-message");
676        command.arg(path);
677    }
678    let configured_model = provider_cfg
679        .model
680        .as_deref()
681        .map(str::trim)
682        .filter(|value| !value.is_empty())
683        .map(str::to_string);
684    if let Some(model_value) = configured_model.as_deref() {
685        match provider_cfg.model_arg.as_deref().map(str::trim) {
686            Some(model_arg) if !model_arg.is_empty() => {
687                command.arg(model_arg);
688                command.arg(model_value);
689            }
690            _ => {
691                command.arg(model_value);
692            }
693        }
694    }
695    command.stdout(Stdio::piped()).stderr(Stdio::piped());
696
697    let prompt_transport = normalized_prompt_transport(provider, &provider_cfg.prompt_transport);
698    match prompt_transport {
699        PromptTransport::Stdin => {
700            command.stdin(Stdio::piped());
701        }
702        PromptTransport::Argv => {
703            command.arg(prompt);
704            command.stdin(Stdio::null());
705        }
706    }
707
708    let mut child = command
709        .spawn()
710        .with_context(|| format!("failed to start provider client '{}'", provider_cfg.client))?;
711    debug!(
712        provider = %provider.as_str(),
713        client = %provider_cfg.client,
714        prompt_chars = prompt.chars().count(),
715        "provider process spawned"
716    );
717    emit_progress(
718        progress_sender.as_ref(),
719        provider,
720        "system",
721        format!(
722            "spawned {} (mode={}, transport={})",
723            provider_cfg.client,
724            mode.as_str(),
725            match prompt_transport {
726                PromptTransport::Stdin => "stdin",
727                PromptTransport::Argv => "argv",
728            }
729        ),
730    );
731
732    if matches!(prompt_transport, PromptTransport::Stdin)
733        && let Some(mut stdin) = child.stdin.take()
734    {
735        stdin
736            .write_all(prompt.as_bytes())
737            .await
738            .context("failed to send prompt to provider stdin")?;
739        stdin.flush().await.ok();
740    }
741
742    let stdout_task = child.stdout.take().map(|stdout| {
743        tokio::spawn(read_stream(
744            stdout,
745            provider,
746            "stdout",
747            progress_sender.clone(),
748        ))
749    });
750    let stderr_task = child.stderr.take().map(|stderr| {
751        tokio::spawn(read_stream(
752            stderr,
753            provider,
754            "stderr",
755            progress_sender.clone(),
756        ))
757    });
758
759    let timeout_seconds = effective_timeout_seconds(config, mode);
760    let wait_result = timeout(Duration::from_secs(timeout_seconds), child.wait()).await;
761    let mut timed_out = false;
762    let status = match wait_result {
763        Ok(Ok(status)) => Some(status),
764        Ok(Err(error)) => return Err(anyhow!("provider process wait failed: {error}")),
765        Err(_) => {
766            timed_out = true;
767            let _ = child.kill().await;
768            None
769        }
770    };
771
772    let stdout = collect_stream_output(stdout_task).await;
773    let stderr = collect_stream_output(stderr_task).await;
774    let stderr_trimmed = stderr.trim().to_string();
775    let maybe_codex_reply = read_codex_output_last_message(codex_output_path.as_deref()).await?;
776
777    if timed_out {
778        let reply = maybe_codex_reply
779            .as_deref()
780            .unwrap_or(stdout.trim())
781            .trim()
782            .to_string();
783        if !reply.is_empty() {
784            warn!(
785                provider = %provider.as_str(),
786                mode = %mode.as_str(),
787                timeout_seconds,
788                "provider timed out but returned partial output"
789            );
790            emit_progress(
791                progress_sender.as_ref(),
792                provider,
793                "system",
794                format!("timeout after {timeout_seconds}s, returning partial output"),
795            );
796            return Ok(ProviderInvocation {
797                reply,
798                model: detect_runtime_model(provider, &stdout, &stderr)
799                    .or(configured_model.clone()),
800            });
801        }
802
803        emit_progress(
804            progress_sender.as_ref(),
805            provider,
806            "system",
807            format!("timeout after {timeout_seconds}s with no output"),
808        );
809        return Err(anyhow!(
810            "provider {} timed out after {}s{}",
811            provider.as_str(),
812            timeout_seconds,
813            if stderr_trimmed.is_empty() {
814                "".to_string()
815            } else {
816                format!(": {stderr_trimmed}")
817            }
818        ));
819    }
820    let status = status.expect("status is present when not timed out");
821
822    if !status.success() {
823        warn!(
824            provider = %provider.as_str(),
825            status = %status,
826            stderr = %stderr_trimmed,
827            "provider exited with non-zero status"
828        );
829        emit_progress(
830            progress_sender.as_ref(),
831            provider,
832            "system",
833            format!("provider exited with {status}: {stderr_trimmed}"),
834        );
835        return Err(anyhow!(
836            "provider exited with {}: {}",
837            status,
838            if stderr_trimmed.is_empty() {
839                "no stderr output".to_string()
840            } else {
841                stderr_trimmed
842            }
843        ));
844    }
845
846    let reply = maybe_codex_reply.unwrap_or_else(|| stdout.trim().to_string());
847    if reply.is_empty() {
848        warn!(provider = %provider.as_str(), "provider returned empty output");
849        emit_progress(
850            progress_sender.as_ref(),
851            provider,
852            "system",
853            "provider returned empty output",
854        );
855        return Err(anyhow!("provider returned empty output"));
856    }
857
858    emit_progress(
859        progress_sender.as_ref(),
860        provider,
861        "system",
862        "provider completed successfully",
863    );
864    Ok(ProviderInvocation {
865        reply,
866        model: detect_runtime_model(provider, &stdout, &stderr).or(configured_model),
867    })
868}
869
/// Formats an AI reply body: an optional `Model: <name>` header followed by a blank line
/// (omitted when the model is absent or blank), then `reply` with trailing whitespace
/// removed.
fn format_ai_reply_body(model: Option<&str>, reply: &str) -> String {
    let body = reply.trim_end();
    match model.map(str::trim) {
        Some(name) if !name.is_empty() => format!("Model: {name}\n\n{body}"),
        _ => body.to_string(),
    }
}
878
879fn detect_runtime_model(provider: AiProvider, stdout: &str, stderr: &str) -> Option<String> {
880    match provider {
881        AiProvider::Codex => detect_model_from_json_stream(stdout)
882            .or_else(|| detect_model_from_json_stream(stderr))
883            .or_else(|| detect_model_from_text(stdout))
884            .or_else(|| detect_model_from_text(stderr)),
885        AiProvider::Claude | AiProvider::Opencode => {
886            detect_model_from_text(stdout).or_else(|| detect_model_from_text(stderr))
887        }
888    }
889}
890
891fn detect_model_from_json_stream(stream: &str) -> Option<String> {
892    for line in stream.lines() {
893        let trimmed = line.trim();
894        if trimmed.is_empty() || !trimmed.starts_with('{') {
895            continue;
896        }
897        let Ok(value) = serde_json::from_str::<Value>(trimmed) else {
898            continue;
899        };
900        if let Some(model) = extract_model_from_json(&value) {
901            return Some(model);
902        }
903    }
904    None
905}
906
907fn extract_model_from_json(value: &Value) -> Option<String> {
908    match value {
909        Value::Object(map) => {
910            for key in [
911                "model",
912                "model_id",
913                "model_slug",
914                "resolved_model",
915                "selected_model",
916            ] {
917                if let Some(Value::String(found)) = map.get(key) {
918                    let trimmed = found.trim();
919                    if !trimmed.is_empty() {
920                        return Some(trimmed.to_string());
921                    }
922                }
923            }
924            for nested in map.values() {
925                if let Some(found) = extract_model_from_json(nested) {
926                    return Some(found);
927                }
928            }
929            None
930        }
931        Value::Array(items) => {
932            for item in items {
933                if let Some(found) = extract_model_from_json(item) {
934                    return Some(found);
935                }
936            }
937            None
938        }
939        _ => None,
940    }
941}
942
943fn detect_model_from_text(text: &str) -> Option<String> {
944    for line in text.lines() {
945        if let Some(value) = extract_model_after_marker(line, "model:") {
946            return Some(value);
947        }
948        if let Some(value) = extract_model_after_marker(line, "model=") {
949            return Some(value);
950        }
951    }
952    None
953}
954
/// Returns the first whitespace-delimited token after the first occurrence of
/// `marker` in `line`, with surrounding quotes, commas and semicolons stripped.
///
/// Returns `None` when the marker is absent, nothing follows it, or the token is
/// empty after stripping (e.g. `model: ""`).
fn extract_model_after_marker(line: &str, marker: &str) -> Option<String> {
    const QUOTES_AND_SEPARATORS: &[char] = &['"', '\'', ',', ';'];
    let start = line.find(marker)? + marker.len();
    let token = line[start..]
        .split_whitespace()
        .next()?
        .trim_matches(QUOTES_AND_SEPARATORS);
    (!token.is_empty()).then(|| token.to_string())
}
966
967fn normalized_provider_args(
968    provider: AiProvider,
969    provider_cfg: &crate::domain::config::AiProviderConfig,
970    mode: AiSessionMode,
971) -> Vec<String> {
972    let mut args = provider_cfg.args.clone();
973    match provider {
974        AiProvider::Codex => {
975            if !args.first().map(|value| value == "exec").unwrap_or(false) {
976                args.insert(0, "exec".to_string());
977            }
978            if !args.iter().any(|arg| arg == "--full-auto") {
979                args.push("--full-auto".to_string());
980            }
981            let has_sandbox_flag = args.iter().any(|arg| arg == "--sandbox" || arg == "-s");
982            if !has_sandbox_flag {
983                args.push("--sandbox".to_string());
984                args.push(match mode {
985                    AiSessionMode::Reply => "read-only".to_string(),
986                    AiSessionMode::Refactor => "workspace-write".to_string(),
987                });
988            }
989        }
990        AiProvider::Claude => {
991            if !args.iter().any(|arg| arg == "-p" || arg == "--print") {
992                args.insert(0, "-p".to_string());
993            }
994        }
995        AiProvider::Opencode => {
996            if !args.first().map(|value| value == "run").unwrap_or(false) {
997                args.insert(0, "run".to_string());
998            }
999        }
1000    }
1001    args
1002}
1003
1004fn codex_output_path(provider: AiProvider) -> Result<Option<std::path::PathBuf>> {
1005    if !matches!(provider, AiProvider::Codex) {
1006        return Ok(None);
1007    }
1008    let file = format!("parley-codex-last-{}-{}.txt", now_ms()?, std::process::id());
1009    Ok(Some(std::env::temp_dir().join(file)))
1010}
1011
1012async fn read_codex_output_last_message(path: Option<&std::path::Path>) -> Result<Option<String>> {
1013    let Some(path) = path else {
1014        return Ok(None);
1015    };
1016    let text = match fs::read_to_string(path).await {
1017        Ok(content) => content.trim().to_string(),
1018        Err(_) => String::new(),
1019    };
1020    let _ = fs::remove_file(path).await;
1021    if text.is_empty() {
1022        Ok(None)
1023    } else {
1024        Ok(Some(text))
1025    }
1026}
1027
1028async fn read_stream<R>(
1029    reader: R,
1030    provider: AiProvider,
1031    stream: &'static str,
1032    progress_sender: Option<mpsc::Sender<AiProgressEvent>>,
1033) -> String
1034where
1035    R: AsyncRead + Unpin + Send + 'static,
1036{
1037    let mut lines = BufReader::new(reader).lines();
1038    let mut out = String::new();
1039    while let Ok(Some(line)) = lines.next_line().await {
1040        info!(provider = %provider.as_str(), stream, payload = %line, "provider_stream");
1041        emit_progress(progress_sender.as_ref(), provider, stream, line.as_str());
1042        out.push_str(&line);
1043        out.push('\n');
1044    }
1045    out
1046}
1047
1048async fn collect_stream_output(task: Option<JoinHandle<String>>) -> String {
1049    let Some(task) = task else {
1050        return String::new();
1051    };
1052    match task.await {
1053        Ok(content) => content,
1054        Err(error) => format!("<stream task join failed: {error}>"),
1055    }
1056}
1057
1058fn normalized_prompt_transport(
1059    provider: AiProvider,
1060    configured: &PromptTransport,
1061) -> PromptTransport {
1062    let _ = configured;
1063    match provider {
1064        // Prefer explicit prompt argv for deterministic headless execution.
1065        AiProvider::Codex | AiProvider::Claude | AiProvider::Opencode => PromptTransport::Argv,
1066    }
1067}
1068
1069fn emit_progress(
1070    progress_sender: Option<&mpsc::Sender<AiProgressEvent>>,
1071    provider: AiProvider,
1072    stream: &str,
1073    message: impl Into<String>,
1074) {
1075    let Some(progress_sender) = progress_sender else {
1076        return;
1077    };
1078    let timestamp_ms = std::time::SystemTime::now()
1079        .duration_since(std::time::UNIX_EPOCH)
1080        .map(|elapsed| elapsed.as_millis() as u64)
1081        .unwrap_or(0);
1082    let _ = progress_sender.send(AiProgressEvent {
1083        timestamp_ms,
1084        provider: provider.as_str().to_string(),
1085        stream: stream.to_string(),
1086        message: message.into(),
1087    });
1088}
1089
1090fn comment_is_targetable(status: CommentStatus, mode: AiSessionMode) -> bool {
1091    match mode {
1092        AiSessionMode::Reply => {
1093            matches!(status, CommentStatus::Open | CommentStatus::Pending)
1094        }
1095        AiSessionMode::Refactor => matches!(status, CommentStatus::Open),
1096    }
1097}
1098
1099fn effective_timeout_seconds(config: &AppConfig, mode: AiSessionMode) -> u64 {
1100    let configured = config.ai.timeout_seconds.max(1);
1101    match mode {
1102        AiSessionMode::Reply => configured,
1103        // Refactor mode can involve tool execution and file edits; keep a higher floor.
1104        AiSessionMode::Refactor => configured.max(600),
1105    }
1106}
1107
1108fn now_ms() -> Result<u64> {
1109    let elapsed = std::time::SystemTime::now()
1110        .duration_since(std::time::UNIX_EPOCH)
1111        .context("system clock is before unix epoch")?;
1112    Ok(elapsed.as_millis() as u64)
1113}
1114
// Unit tests for the pure helpers in this module. They cover comment targeting
// per session mode, hunk selection and excerpts, AI reply formatting, and
// runtime-model detection from provider output.
#[cfg(test)]
mod tests {
    use super::{
        choose_best_hunk, comment_is_targetable, detect_model_from_json_stream,
        detect_model_from_text, format_ai_reply_body, format_hunk_excerpt, hunk_distance_to_anchor,
    };
    use crate::domain::ai::AiSessionMode;
    use crate::domain::diff::{DiffFile, DiffHunk, DiffLine, DiffLineKind};
    use crate::domain::review::CommentStatus;

    // Reply mode accepts Open and Pending threads but not Addressed ones.
    #[test]
    fn reply_mode_excludes_addressed_threads() {
        assert!(comment_is_targetable(
            CommentStatus::Open,
            AiSessionMode::Reply
        ));
        assert!(comment_is_targetable(
            CommentStatus::Pending,
            AiSessionMode::Reply
        ));
        assert!(!comment_is_targetable(
            CommentStatus::Addressed,
            AiSessionMode::Reply
        ));
    }

    // Refactor mode accepts only Open threads.
    #[test]
    fn refactor_mode_targets_only_open_threads() {
        assert!(comment_is_targetable(
            CommentStatus::Open,
            AiSessionMode::Refactor
        ));
        assert!(!comment_is_targetable(
            CommentStatus::Pending,
            AiSessionMode::Refactor
        ));
        assert!(!comment_is_targetable(
            CommentStatus::Addressed,
            AiSessionMode::Refactor
        ));
    }

    // New-side line 41 is inside the second hunk, so that hunk must be chosen.
    #[test]
    fn choose_best_hunk_prefers_exact_anchor_match() {
        let file = DiffFile {
            path: "src/lib.rs".to_string(),
            header_lines: Vec::new(),
            hunks: vec![
                make_hunk(
                    "@@ -1,3 +1,3 @@",
                    1,
                    1,
                    vec![line_ctx(1, 1), line_ctx(2, 2)],
                ),
                make_hunk(
                    "@@ -40,3 +40,3 @@",
                    40,
                    40,
                    vec![line_ctx(40, 40), line_ctx(41, 41)],
                ),
            ],
        };

        let chosen = choose_best_hunk(&file, None, Some(41)).expect("hunk should be selected");
        assert_eq!(chosen.new_start, 40);
    }

    // Anchor 74 lies in neither hunk; the hunk starting at 80 is closer than the
    // one at 10, so it is expected to win.
    #[test]
    fn choose_best_hunk_falls_back_to_nearest_start() {
        let file = DiffFile {
            path: "src/lib.rs".to_string(),
            header_lines: Vec::new(),
            hunks: vec![
                make_hunk("@@ -10,2 +10,2 @@", 10, 10, vec![line_ctx(10, 10)]),
                make_hunk("@@ -80,2 +80,2 @@", 80, 80, vec![line_ctx(80, 80)]),
            ],
        };

        let chosen = choose_best_hunk(&file, None, Some(74)).expect("hunk should be selected");
        assert_eq!(chosen.new_start, 80);
        assert!(hunk_distance_to_anchor(chosen, None, Some(74)) < 10);
    }

    // The excerpt around new-side line 21 must include both the anchored added
    // line and the hunk header.
    #[test]
    fn hunk_excerpt_contains_anchor_line() {
        let hunk = make_hunk(
            "@@ -20,4 +20,4 @@",
            20,
            20,
            vec![
                line_ctx(20, 20),
                line_add(0, 21, "+let value = 1;"),
                line_ctx(22, 22),
            ],
        );
        let excerpt = format_hunk_excerpt(&hunk, None, Some(21), 8);
        assert!(excerpt.contains("+let value = 1;"));
        assert!(excerpt.contains("@@ -20,4 +20,4 @@"));
    }

    #[test]
    fn ai_reply_body_includes_model_header() {
        let body = format_ai_reply_body(Some("gpt-5.4"), "Implemented fix.");
        assert!(body.starts_with("Model: gpt-5.4"));
        assert!(body.contains("Implemented fix."));
    }

    #[test]
    fn ai_reply_body_omits_header_when_model_unknown() {
        let body = format_ai_reply_body(None, "Implemented fix.");
        assert_eq!(body, "Implemented fix.");
    }

    // The model key may be nested several objects deep inside a JSON event line.
    #[test]
    fn detect_model_from_json_stream_reads_nested_model_slug() {
        let stream = r#"{"event":"meta","payload":{"session":{"model_slug":"gpt-5.4"}}}"#;
        let detected = detect_model_from_json_stream(stream).expect("model should be detected");
        assert_eq!(detected, "gpt-5.4");
    }

    // `model=` marker; the trailing `;` must be stripped from the token.
    #[test]
    fn detect_model_from_text_reads_model_marker() {
        let detected =
            detect_model_from_text("run complete; model=gpt-5.4; tokens=100").expect("model");
        assert_eq!(detected, "gpt-5.4");
    }

    /// Builds a `DiffHunk` whose first line is a `HunkHeader` carrying `header`,
    /// followed by `extra`. `old_count`/`new_count` are fixed at 1 regardless of
    /// the header text.
    fn make_hunk(
        header: &str,
        old_start: u32,
        new_start: u32,
        mut extra: Vec<DiffLine>,
    ) -> DiffHunk {
        let mut lines = vec![DiffLine {
            kind: DiffLineKind::HunkHeader,
            old_line: None,
            new_line: None,
            raw: header.to_string(),
            code: header.to_string(),
        }];
        lines.append(&mut extra);
        DiffHunk {
            old_start,
            old_count: 1,
            new_start,
            new_count: 1,
            header: header.to_string(),
            lines,
        }
    }

    /// Context line present on both sides at `old` / `new`.
    fn line_ctx(old: u32, new: u32) -> DiffLine {
        DiffLine {
            kind: DiffLineKind::Context,
            old_line: Some(old),
            new_line: Some(new),
            raw: format!(" context {old}:{new}"),
            code: format!("context {old}:{new}"),
        }
    }

    /// Added line at new-side `new`; `old == 0` means "no old-side line". `code`
    /// is `raw` with leading `+` characters stripped.
    fn line_add(old: u32, new: u32, raw: &str) -> DiffLine {
        DiffLine {
            kind: DiffLineKind::Added,
            old_line: if old == 0 { None } else { Some(old) },
            new_line: Some(new),
            raw: raw.to_string(),
            code: raw.trim_start_matches('+').to_string(),
        }
    }
}