Skip to main content

parley/services/
ai_session.rs

1use self::progress::emit_progress;
2use self::prompt::{build_thread_prompt, load_task_prompt_override};
3use self::provider::{format_ai_reply_body, invoke_provider};
4use crate::domain::ai::{AiProvider, AiSessionMode};
5use crate::domain::config::{AgentTransport, AiProviderConfig, AppConfig};
6use crate::domain::diff::DiffDocument;
7use crate::domain::review::{Author, CommentStatus, ReviewSession};
8use crate::git::diff::{DiffSource, load_git_diff};
9use crate::services::review_service::{AddReplyInput, ReviewService};
10use crate::utils::time::now_ms;
11use anyhow::{Result, anyhow};
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
16pub(crate) mod json_text;
17mod progress;
18mod prompt;
19mod provider;
20
21#[cfg(test)]
22mod tests;
23
24use std::path::PathBuf;
25
/// Parameters for a single AI session run over one review.
#[derive(Debug, Clone)]
pub struct RunAiSessionInput {
    /// Name of the review to load and update.
    pub review_name: String,
    /// Provider that will be asked to reply to each thread.
    pub provider: AiProvider,
    /// Optional transport override; when `None`, `config.ai.default_transport` is used to pick
    /// the provider configuration.
    pub transport: Option<AgentTransport>,
    /// Explicit comment threads to process. When empty, every comment whose status is
    /// targetable (open or pending) is processed.
    pub comment_ids: Vec<u64>,
    /// Session mode (reply vs. refactor); selects the task prompt and result messages.
    pub mode: AiSessionMode,
    /// Which diff to load as prompt context.
    pub diff_source: DiffSource,
    /// Working tree to run in; diff loading falls back to `.` when this is `None`.
    pub worktree_path: Option<PathBuf>,
}
36
/// Serializable summary of an AI session: identity of the run plus per-thread outcomes.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiSessionResult {
    /// Review the session ran against.
    pub review_name: String,
    /// Provider name as reported by `AiProvider::as_str`.
    pub provider: String,
    /// Session mode as reported by `AiSessionMode::as_str`.
    pub mode: String,
    /// Transport of the resolved provider configuration.
    pub transport: String,
    /// Client configured for the resolved provider.
    pub client: String,
    /// Model configured for the resolved provider, if any.
    pub model: Option<String>,
    /// Identifier of the form `<review>-<provider>-<start_ms>`.
    pub session_id: String,
    /// Number of threads that received a persisted AI reply.
    pub processed: usize,
    /// Number of threads (or placeholder entries) that were skipped.
    pub skipped: usize,
    /// Number of threads that failed.
    pub failed: usize,
    /// Per-thread outcomes in processing order.
    pub items: Vec<AiSessionItemResult>,
}
52
/// Outcome for one thread in an AI session.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiSessionItemResult {
    /// Thread id; `0` is used for the single "nothing to process" placeholder entry.
    pub comment_id: u64,
    /// One of `"processed"`, `"skipped"` or `"failed"`.
    pub status: String,
    /// Human-readable explanation of the outcome.
    pub message: String,
}
60
61impl AiSessionResult {
62    fn new(input: &RunAiSessionInput, provider_cfg: &AiProviderConfig, now_ms: u64) -> Self {
63        Self {
64            review_name: input.review_name.clone(),
65            provider: input.provider.as_str().to_string(),
66            mode: input.mode.as_str().to_string(),
67            transport: provider_cfg.transport.as_str().to_string(),
68            client: provider_cfg.client.clone(),
69            model: provider_cfg.model.clone(),
70            session_id: format!("{}-{}-{now_ms}", input.review_name, input.provider.as_str()),
71            processed: 0,
72            skipped: 0,
73            failed: 0,
74            items: Vec::new(),
75        }
76    }
77
78    fn push_processed(&mut self, comment_id: u64, message: impl Into<String>) {
79        self.processed += 1;
80        self.push_item(comment_id, "processed", message);
81    }
82
83    fn push_skipped(&mut self, comment_id: u64, message: impl Into<String>) {
84        self.skipped += 1;
85        self.push_item(comment_id, "skipped", message);
86    }
87
88    fn push_failed(&mut self, comment_id: u64, message: impl Into<String>) {
89        self.failed += 1;
90        self.push_item(comment_id, "failed", message);
91    }
92
93    fn push_item(&mut self, comment_id: u64, status: &str, message: impl Into<String>) {
94        self.items.push(AiSessionItemResult {
95            comment_id,
96            status: status.to_string(),
97            message: message.into(),
98        });
99    }
100}
101
/// A progress notification streamed to callers of [`run_ai_session_with_progress`].
// NOTE(review): instances are presumably built by `progress::emit_progress` (not visible here).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiProgressEvent {
    /// Event time in milliseconds (presumably since the Unix epoch — confirm in `progress`).
    pub timestamp_ms: u64,
    /// Provider that produced the event.
    pub provider: String,
    /// Stream label; this module emits `"system"` for session bookkeeping — other values
    /// presumably come from provider output, confirm in `provider`.
    pub stream: String,
    /// Free-form progress text.
    pub message: String,
}
110
111#[must_use]
112pub fn default_ai_session_mode(comment_ids: &[u64]) -> AiSessionMode {
113    if comment_ids.is_empty() {
114        AiSessionMode::Refactor
115    } else {
116        AiSessionMode::Reply
117    }
118}
119
120/// # Errors
121///
122/// Returns an error when the review/config cannot be loaded, the clock is invalid, provider
123/// invocation fails, or review updates cannot be persisted.
124pub async fn run_ai_session(
125    service: &ReviewService,
126    input: RunAiSessionInput,
127) -> Result<AiSessionResult> {
128    run_ai_session_inner(service, input, None).await
129}
130
131/// # Errors
132///
133/// Returns an error for the same load, provider, clock, and persistence failures as
134/// [`run_ai_session`].
135pub async fn run_ai_session_with_progress(
136    service: &ReviewService,
137    input: RunAiSessionInput,
138    progress_sender: mpsc::UnboundedSender<AiProgressEvent>,
139) -> Result<AiSessionResult> {
140    run_ai_session_inner(service, input, Some(progress_sender)).await
141}
142
143async fn run_ai_session_inner(
144    service: &ReviewService,
145    input: RunAiSessionInput,
146    progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
147) -> Result<AiSessionResult> {
148    info!(
149        review = %input.review_name,
150        provider = %input.provider.as_str(),
151        requested_comments = input.comment_ids.len(),
152        "starting ai session"
153    );
154    let config = service.load_config().await?;
155    let mut review = service.load_review(&input.review_name).await?;
156    let worktree = input
157        .worktree_path
158        .as_deref()
159        .unwrap_or_else(|| std::path::Path::new("."));
160    let diff_document = match load_git_diff(&config, &input.diff_source, worktree).await {
161        Ok(document) => Some(document),
162        Err(error) => {
163            warn!(error = %error, "ai session prompt context: unable to load git diff");
164            None
165        }
166    };
167    let now_ms = now_ms()?;
168    let effective_transport = input.transport.or(config.ai.default_transport);
169    let provider_cfg = config
170        .ai
171        .provider_config_for_transport(input.provider, effective_transport);
172    let mut result = AiSessionResult::new(&input, &provider_cfg, now_ms);
173
174    let target_ids = ai_session_target_ids(&review, &input.comment_ids);
175    let total_targets = target_ids.len();
176    if total_targets == 0 {
177        result.push_skipped(0, no_targets_message(input.mode));
178        emit_progress(
179            progress_sender.as_ref(),
180            input.provider,
181            "system",
182            "no open threads to process",
183        );
184        return Ok(result);
185    }
186
187    let task_prompt_override = load_task_prompt_override(&config, input.mode).await?;
188    let context = AiSessionExecutionContext {
189        service,
190        config: &config,
191        input: &input,
192        diff_document: diff_document.as_ref(),
193        task_prompt_override: task_prompt_override.as_deref(),
194        progress_sender,
195    };
196    process_ai_session_targets(&context, &mut review, &mut result, target_ids).await?;
197
198    info!(
199        review = %input.review_name,
200        provider = %input.provider.as_str(),
201        processed = result.processed,
202        skipped = result.skipped,
203        failed = result.failed,
204        "ai session completed"
205    );
206    Ok(result)
207}
208
/// Borrowed, read-only state shared by every thread processed in one session.
struct AiSessionExecutionContext<'a> {
    /// Service used to persist AI replies.
    service: &'a ReviewService,
    /// Application config passed to the provider invocation.
    config: &'a AppConfig,
    /// The original session request.
    input: &'a RunAiSessionInput,
    /// Diff used as prompt context; `None` when it could not be loaded.
    diff_document: Option<&'a DiffDocument>,
    /// Optional replacement for the built-in task prompt of the selected mode.
    task_prompt_override: Option<&'a str>,
    /// Progress channel; `None` when the caller did not ask for progress events.
    progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
}
217
218async fn process_ai_session_targets(
219    context: &AiSessionExecutionContext<'_>,
220    review: &mut ReviewSession,
221    result: &mut AiSessionResult,
222    target_ids: Vec<u64>,
223) -> Result<()> {
224    let total_targets = target_ids.len();
225    for (step_index, comment_id) in target_ids.into_iter().enumerate() {
226        let step_number = step_index + 1;
227        emit_progress(
228            context.progress_sender.as_ref(),
229            context.input.provider,
230            "system",
231            format!("thread #{comment_id}: start ({step_number}/{total_targets})"),
232        );
233        debug!(
234            review = %context.input.review_name,
235            provider = %context.input.provider.as_str(),
236            comment_id,
237            "processing ai thread"
238        );
239        process_ai_session_target(
240            context,
241            review,
242            result,
243            comment_id,
244            step_number,
245            total_targets,
246        )
247        .await?;
248    }
249
250    Ok(())
251}
252
253async fn process_ai_session_target(
254    context: &AiSessionExecutionContext<'_>,
255    review: &mut ReviewSession,
256    result: &mut AiSessionResult,
257    comment_id: u64,
258    step_number: usize,
259    total_targets: usize,
260) -> Result<()> {
261    let opt_status = comment_status(review, comment_id);
262    if opt_status.is_none() {
263        warn!(
264            review = %context.input.review_name,
265            provider = %context.input.provider.as_str(),
266            comment_id,
267            "ai session target comment not found"
268        );
269        result.push_failed(comment_id, "comment not found in review");
270        emit_progress(
271            context.progress_sender.as_ref(),
272            context.input.provider,
273            "system",
274            format!("thread #{comment_id}: failed (comment not found)"),
275        );
276        return Ok(());
277    }
278    let comment_status = opt_status.unwrap();
279
280    if !comment_is_targetable(comment_status) {
281        debug!(
282            review = %context.input.review_name,
283            provider = %context.input.provider.as_str(),
284            comment_id,
285            status = ?comment_status,
286            "skipping non-targetable comment for selected mode"
287        );
288        result.push_skipped(
289            comment_id,
290            format!(
291                "comment status {:?} is not targetable for {} mode",
292                comment_status,
293                context.input.mode.as_str()
294            ),
295        );
296        emit_progress(
297            context.progress_sender.as_ref(),
298            context.input.provider,
299            "system",
300            format!("thread #{comment_id}: skipped (status={comment_status:?})"),
301        );
302        return Ok(());
303    }
304
305    let prompt = build_thread_prompt(
306        &context.input.review_name,
307        comment_id,
308        review,
309        context.diff_document,
310        context.input.mode,
311        context.task_prompt_override,
312    )
313    .await?;
314    let provider_reply = match invoke_provider(
315        context.config,
316        context.input.provider,
317        context.input.transport,
318        context.input.mode,
319        &prompt,
320        context.progress_sender.clone(),
321        context.input.worktree_path.as_deref(),
322    )
323    .await
324    {
325        Ok(reply) => reply,
326        Err(error) => {
327            error!(
328                review = %context.input.review_name,
329                provider = %context.input.provider.as_str(),
330                comment_id,
331                error = %error,
332                "provider invocation failed"
333            );
334            result.push_failed(comment_id, format!("provider failed: {error}"));
335            emit_progress(
336                context.progress_sender.as_ref(),
337                context.input.provider,
338                "system",
339                format!("thread #{comment_id}: failed ({error})"),
340            );
341            return Ok(());
342        }
343    };
344    let parsed_reply = match parse_ai_thread_reply_json(&provider_reply.reply, comment_id) {
345        Ok(parsed_reply) => parsed_reply,
346        Err(error) => {
347            result.push_failed(comment_id, format!("invalid AI reply JSON: {error}"));
348            emit_progress(
349                context.progress_sender.as_ref(),
350                context.input.provider,
351                "system",
352                format!("thread #{comment_id}: failed (invalid AI reply JSON: {error})"),
353            );
354            return Ok(());
355        }
356    };
357    let reply_body = format_ai_reply_body(provider_reply.model.as_deref(), &parsed_reply.reply);
358
359    *review = match context
360        .service
361        .add_reply(
362            &context.input.review_name,
363            AddReplyInput {
364                comment_id: parsed_reply.thread_id,
365                author: Author::Ai,
366                body: reply_body,
367            },
368        )
369        .await
370    {
371        Ok(updated) => updated,
372        Err(error) => {
373            error!(
374                review = %context.input.review_name,
375                provider = %context.input.provider.as_str(),
376                comment_id,
377                error = %error,
378                "failed to persist ai reply"
379            );
380            result.push_failed(comment_id, format!("failed to persist ai reply: {error}"));
381            emit_progress(
382                context.progress_sender.as_ref(),
383                context.input.provider,
384                "system",
385                format!("thread #{comment_id}: failed (persist reply: {error})"),
386            );
387            return Ok(());
388        }
389    };
390
391    info!(
392        review = %context.input.review_name,
393        provider = %context.input.provider.as_str(),
394        comment_id,
395        "ai reply persisted"
396    );
397    result.push_processed(comment_id, processed_target_message(context.input.mode));
398    emit_progress(
399        context.progress_sender.as_ref(),
400        context.input.provider,
401        "system",
402        format!(
403            "thread #{comment_id}: reply persisted; status pending_human ({step_number}/{total_targets})"
404        ),
405    );
406    Ok(())
407}
408
409fn ai_session_target_ids(review: &ReviewSession, comment_ids: &[u64]) -> Vec<u64> {
410    if !comment_ids.is_empty() {
411        return comment_ids.to_vec();
412    }
413
414    review
415        .comments
416        .iter()
417        .filter(|comment| comment_is_targetable(&comment.status))
418        .map(|comment| comment.id)
419        .collect()
420}
421
422fn comment_status(review: &ReviewSession, comment_id: u64) -> Option<&CommentStatus> {
423    review.comments.iter().find_map(|comment| {
424        if comment.id == comment_id {
425            Some(&comment.status)
426        } else {
427            None
428        }
429    })
430}
431
432fn no_targets_message(mode: AiSessionMode) -> &'static str {
433    match mode {
434        AiSessionMode::Reply => "no replyable threads to process",
435        AiSessionMode::Refactor => "no open threads to process",
436    }
437}
438
439fn processed_target_message(mode: AiSessionMode) -> &'static str {
440    match mode {
441        AiSessionMode::Reply => "ai reply added",
442        AiSessionMode::Refactor => "ai reply added; thread status moved to pending_human",
443    }
444}
445
/// A validated AI reply, as returned by `parse_ai_thread_reply_json`.
#[derive(Debug)]
struct ParsedAiThreadReply {
    /// Thread the reply belongs to; already checked to equal the requested thread.
    thread_id: u64,
    /// Reply text, trimmed and guaranteed non-empty.
    reply: String,
}
451
/// Raw wire shape the provider must answer with. Extra keys are rejected.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct AiThreadReplyJson {
    /// Thread the provider claims to be answering.
    thread_id: u64,
    /// Reply body; validated to be non-empty after trimming.
    reply: String,
    /// Requested thread status; only `"pending_human"` is accepted.
    status: String,
}
459
460fn parse_ai_thread_reply_json(
461    raw_reply: &str,
462    expected_thread_id: u64,
463) -> Result<ParsedAiThreadReply> {
464    let json = strip_json_code_fence(raw_reply).trim();
465    let parsed: AiThreadReplyJson = match serde_json::from_str(json) {
466        Ok(parsed) => parsed,
467        Err(error) => {
468            let Some(candidate) = embedded_ai_reply_json_candidate(json) else {
469                return Err(invalid_ai_reply_json_error(error, json));
470            };
471            serde_json::from_str(candidate)
472                .map_err(|error| invalid_ai_reply_json_error(error, candidate))?
473        }
474    };
475
476    if parsed.thread_id != expected_thread_id {
477        return Err(anyhow!(
478            "thread_id {} did not match requested thread {}",
479            parsed.thread_id,
480            expected_thread_id
481        ));
482    }
483
484    if parsed.status != "pending_human" {
485        return Err(anyhow!(
486            "status {:?} did not match required pending_human",
487            parsed.status
488        ));
489    }
490
491    let reply = parsed.reply.trim().to_string();
492    if reply.is_empty() {
493        return Err(anyhow!("reply must not be empty"));
494    }
495
496    Ok(ParsedAiThreadReply {
497        thread_id: parsed.thread_id,
498        reply,
499    })
500}
501
502fn invalid_ai_reply_json_error(error: serde_json::Error, json: &str) -> anyhow::Error {
503    let trimmed = json.trim();
504    if trimmed.is_empty() {
505        return anyhow!(
506            "expected JSON object with thread_id, reply, status: {error}; response was empty"
507        );
508    }
509
510    anyhow!(
511        "expected JSON object with thread_id, reply, status: {error}; response preview: {}",
512        ai_reply_preview(trimmed)
513    )
514}
515
/// Builds a single-line preview of at most 500 characters of `value`, escaping `\r`, `\n`
/// and `\t`, with `...` appended when the input was longer.
fn ai_reply_preview(value: &str) -> String {
    const MAX_PREVIEW_CHARS: usize = 500;
    let mut preview = String::new();
    for ch in value.chars().take(MAX_PREVIEW_CHARS) {
        match ch {
            '\r' => preview.push_str("\\r"),
            '\n' => preview.push_str("\\n"),
            '\t' => preview.push_str("\\t"),
            other => preview.push(other),
        }
    }
    // A character at index MAX_PREVIEW_CHARS means the input was longer than the preview.
    let truncated = value.chars().nth(MAX_PREVIEW_CHARS).is_some();
    if truncated {
        preview.push_str("...");
    }
    preview
}
530
/// Removes a surrounding Markdown code fence from `raw_reply`, if present.
///
/// The opening fence may carry a `json` info string in any letter case (`json`, `JSON`, ...).
/// The closing fence is optional. Text that does not start with a fence is returned trimmed.
fn strip_json_code_fence(raw_reply: &str) -> &str {
    let trimmed = raw_reply.trim();
    let Some(after_fence) = trimmed.strip_prefix("```") else {
        return trimmed;
    };

    // `get(..4)` is `None` for short input or a non-char boundary, so slicing below is safe.
    let body = match after_fence.get(..4) {
        Some(tag) if tag.eq_ignore_ascii_case("json") => &after_fence[4..],
        _ => after_fence,
    };

    let body = body.trim_start();
    body.strip_suffix("```").map_or(body, str::trim)
}
552
553fn embedded_ai_reply_json_candidate(value: &str) -> Option<&str> {
554    let mut search_start = 0;
555    while search_start < value.len() {
556        let start = value.get(search_start..)?.find('{')? + search_start;
557        let end = balanced_json_object_end(value, start)?;
558        let candidate = &value[start..end];
559        if has_ai_reply_schema_keys(candidate) {
560            return Some(candidate);
561        }
562        search_start = end;
563    }
564    None
565}
566
567fn has_ai_reply_schema_keys(candidate: &str) -> bool {
568    let Ok(value) = serde_json::from_str::<serde_json::Value>(candidate) else {
569        return false;
570    };
571    let Some(object) = value.as_object() else {
572        return false;
573    };
574    object.contains_key("thread_id")
575        && object.contains_key("reply")
576        && object.contains_key("status")
577}
578
/// Returns the byte index just past the `}` that closes the object opening at `start`.
///
/// Braces inside JSON strings (including escaped quotes) are ignored. Returns `None` when
/// `start` is out of range or not a char boundary, when a `}` appears before any `{`, or
/// when the object never closes.
fn balanced_json_object_end(value: &str, start: usize) -> Option<usize> {
    let tail = value.get(start..)?;
    let mut open_braces: usize = 0;
    let mut inside_string = false;
    let mut pending_escape = false;

    for (offset, ch) in tail.char_indices() {
        if inside_string {
            match ch {
                _ if pending_escape => pending_escape = false,
                '\\' => pending_escape = true,
                '"' => inside_string = false,
                _ => {}
            }
        } else {
            match ch {
                '"' => inside_string = true,
                '{' => open_braces = open_braces.saturating_add(1),
                '}' => {
                    open_braces = open_braces.checked_sub(1)?;
                    if open_braces == 0 {
                        return Some(start + offset + ch.len_utf8());
                    }
                }
                _ => {}
            }
        }
    }
    None
}
611
612fn comment_is_targetable(status: &CommentStatus) -> bool {
613    matches!(status, CommentStatus::Open | CommentStatus::Pending)
614}