Skip to main content

parley/services/
ai_session.rs

1use self::progress::emit_progress;
2use self::prompt::{build_thread_prompt, load_task_prompt_override};
3use self::provider::{format_ai_reply_body, invoke_provider};
4use crate::domain::ai::{AiProvider, AiSessionMode};
5use crate::domain::config::{AgentTransport, AiProviderConfig, AppConfig};
6use crate::domain::diff::DiffDocument;
7use crate::domain::review::{Author, CommentStatus, ReviewSession};
8use crate::git::diff::{DiffSource, load_git_diff};
9use crate::services::review_service::{AddReplyInput, ReviewService};
10use crate::utils::time::now_ms;
11use anyhow::{Result, anyhow};
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
16pub(crate) mod json_text;
17mod progress;
18mod prompt;
19mod provider;
20
21#[cfg(test)]
22mod tests;
23
/// Parameters for a single AI session run against one review.
#[derive(Debug, Clone)]
pub struct RunAiSessionInput {
    /// Name of the review loaded through `ReviewService::load_review`.
    pub review_name: String,
    /// Provider invoked for every target thread.
    pub provider: AiProvider,
    /// Explicit transport override. When `None`, `config.ai.default_transport` is used to
    /// resolve the provider config reported in the session result.
    pub transport: Option<AgentTransport>,
    /// Threads to process. When empty, every comment whose status is `Open` or `Pending` is
    /// selected from the review.
    pub comment_ids: Vec<u64>,
    /// Session mode; passed to prompt building and provider invocation and used for the
    /// result/progress wording.
    pub mode: AiSessionMode,
    /// Source passed to `load_git_diff` for the prompt's diff context.
    pub diff_source: DiffSource,
}
33
/// Summary of one AI session run, serialized with snake_case field names.
///
/// Field order is also the serialized key order, so keep it stable.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiSessionResult {
    /// Review the session ran against.
    pub review_name: String,
    /// Provider name (`AiProvider::as_str`).
    pub provider: String,
    /// Session mode name (`AiSessionMode::as_str`).
    pub mode: String,
    /// Transport of the resolved provider config.
    pub transport: String,
    /// Client of the resolved provider config.
    pub client: String,
    /// Model of the resolved provider config, if one is set.
    pub model: Option<String>,
    /// Identifier of the form `<review_name>-<provider>-<now_ms>`.
    pub session_id: String,
    /// Number of items recorded with status `processed`.
    pub processed: usize,
    /// Number of items recorded with status `skipped`.
    pub skipped: usize,
    /// Number of items recorded with status `failed`.
    pub failed: usize,
    /// Per-thread outcomes, in the order they were recorded.
    pub items: Vec<AiSessionItemResult>,
}
44    pub processed: usize,
45    pub skipped: usize,
46    pub failed: usize,
47    pub items: Vec<AiSessionItemResult>,
48}
49
50#[derive(Debug, Clone, Serialize)]
51#[serde(rename_all = "snake_case")]
52pub struct AiSessionItemResult {
53    pub comment_id: u64,
54    pub status: String,
55    pub message: String,
56}
57
58impl AiSessionResult {
59    fn new(input: &RunAiSessionInput, provider_cfg: &AiProviderConfig, now_ms: u64) -> Self {
60        Self {
61            review_name: input.review_name.clone(),
62            provider: input.provider.as_str().to_string(),
63            mode: input.mode.as_str().to_string(),
64            transport: provider_cfg.transport.as_str().to_string(),
65            client: provider_cfg.client.clone(),
66            model: provider_cfg.model.clone(),
67            session_id: format!("{}-{}-{now_ms}", input.review_name, input.provider.as_str()),
68            processed: 0,
69            skipped: 0,
70            failed: 0,
71            items: Vec::new(),
72        }
73    }
74
75    fn push_processed(&mut self, comment_id: u64, message: impl Into<String>) {
76        self.processed += 1;
77        self.push_item(comment_id, "processed", message);
78    }
79
80    fn push_skipped(&mut self, comment_id: u64, message: impl Into<String>) {
81        self.skipped += 1;
82        self.push_item(comment_id, "skipped", message);
83    }
84
85    fn push_failed(&mut self, comment_id: u64, message: impl Into<String>) {
86        self.failed += 1;
87        self.push_item(comment_id, "failed", message);
88    }
89
90    fn push_item(&mut self, comment_id: u64, status: &str, message: impl Into<String>) {
91        self.items.push(AiSessionItemResult {
92            comment_id,
93            status: status.to_string(),
94            message: message.into(),
95        });
96    }
97}
98
/// A progress message delivered through the channel given to
/// [`run_ai_session_with_progress`].
// NOTE(review): instances are built by `progress::emit_progress` (not visible here); the field
// meanings below are inferred from its call sites in this file — confirm against that module.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiProgressEvent {
    /// Presumably the event creation time in milliseconds — TODO confirm in `progress`.
    pub timestamp_ms: u64,
    /// Provider the event relates to (this file passes the session's `AiProvider`).
    pub provider: String,
    /// Stream label; this file always emits `"system"`, other emitters may use other labels.
    pub stream: String,
    /// Human-readable progress text.
    pub message: String,
}
107
108#[must_use]
109pub fn default_ai_session_mode(comment_ids: &[u64]) -> AiSessionMode {
110    if comment_ids.is_empty() {
111        AiSessionMode::Refactor
112    } else {
113        AiSessionMode::Reply
114    }
115}
116
117/// # Errors
118///
119/// Returns an error when the review/config cannot be loaded, the clock is invalid, provider
120/// invocation fails, or review updates cannot be persisted.
121pub async fn run_ai_session(
122    service: &ReviewService,
123    input: RunAiSessionInput,
124) -> Result<AiSessionResult> {
125    run_ai_session_inner(service, input, None).await
126}
127
128/// # Errors
129///
130/// Returns an error for the same load, provider, clock, and persistence failures as
131/// [`run_ai_session`].
132pub async fn run_ai_session_with_progress(
133    service: &ReviewService,
134    input: RunAiSessionInput,
135    progress_sender: mpsc::UnboundedSender<AiProgressEvent>,
136) -> Result<AiSessionResult> {
137    run_ai_session_inner(service, input, Some(progress_sender)).await
138}
139
140async fn run_ai_session_inner(
141    service: &ReviewService,
142    input: RunAiSessionInput,
143    progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
144) -> Result<AiSessionResult> {
145    info!(
146        review = %input.review_name,
147        provider = %input.provider.as_str(),
148        requested_comments = input.comment_ids.len(),
149        "starting ai session"
150    );
151    let config = service.load_config().await?;
152    let mut review = service.load_review(&input.review_name).await?;
153    let diff_document = match load_git_diff(&config, &input.diff_source).await {
154        Ok(document) => Some(document),
155        Err(error) => {
156            warn!(error = %error, "ai session prompt context: unable to load git diff");
157            None
158        }
159    };
160    let now_ms = now_ms()?;
161    let effective_transport = input.transport.or(config.ai.default_transport);
162    let provider_cfg = config
163        .ai
164        .provider_config_for_transport(input.provider, effective_transport);
165    let mut result = AiSessionResult::new(&input, &provider_cfg, now_ms);
166
167    let target_ids = ai_session_target_ids(&review, &input.comment_ids);
168    let total_targets = target_ids.len();
169    if total_targets == 0 {
170        result.push_skipped(0, no_targets_message(input.mode));
171        emit_progress(
172            progress_sender.as_ref(),
173            input.provider,
174            "system",
175            "no open threads to process",
176        );
177        return Ok(result);
178    }
179
180    let task_prompt_override = load_task_prompt_override(&config, input.mode).await?;
181    let context = AiSessionExecutionContext {
182        service,
183        config: &config,
184        input: &input,
185        diff_document: diff_document.as_ref(),
186        task_prompt_override: task_prompt_override.as_deref(),
187        progress_sender,
188    };
189    process_ai_session_targets(&context, &mut review, &mut result, target_ids).await?;
190
191    info!(
192        review = %input.review_name,
193        provider = %input.provider.as_str(),
194        processed = result.processed,
195        skipped = result.skipped,
196        failed = result.failed,
197        "ai session completed"
198    );
199    Ok(result)
200}
201
/// Per-run state shared, read-only, by every target processed in one AI session.
struct AiSessionExecutionContext<'a> {
    /// Service used to persist AI replies.
    service: &'a ReviewService,
    /// Loaded application config, passed to the provider invocation.
    config: &'a AppConfig,
    /// The caller's session parameters.
    input: &'a RunAiSessionInput,
    /// Git diff used as prompt context; `None` when it could not be loaded.
    diff_document: Option<&'a DiffDocument>,
    /// Task prompt override returned by `load_task_prompt_override` for the session mode, if any.
    task_prompt_override: Option<&'a str>,
    /// Progress channel; `None` when the session was started via `run_ai_session`.
    progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
}
210
211async fn process_ai_session_targets(
212    context: &AiSessionExecutionContext<'_>,
213    review: &mut ReviewSession,
214    result: &mut AiSessionResult,
215    target_ids: Vec<u64>,
216) -> Result<()> {
217    let total_targets = target_ids.len();
218    for (step_index, comment_id) in target_ids.into_iter().enumerate() {
219        let step_number = step_index + 1;
220        emit_progress(
221            context.progress_sender.as_ref(),
222            context.input.provider,
223            "system",
224            format!(
225                "thread #{comment_id}: start ({}/{})",
226                step_number, total_targets
227            ),
228        );
229        debug!(
230            review = %context.input.review_name,
231            provider = %context.input.provider.as_str(),
232            comment_id,
233            "processing ai thread"
234        );
235        process_ai_session_target(
236            context,
237            review,
238            result,
239            comment_id,
240            step_number,
241            total_targets,
242        )
243        .await?;
244    }
245
246    Ok(())
247}
248
249async fn process_ai_session_target(
250    context: &AiSessionExecutionContext<'_>,
251    review: &mut ReviewSession,
252    result: &mut AiSessionResult,
253    comment_id: u64,
254    step_number: usize,
255    total_targets: usize,
256) -> Result<()> {
257    let Some(comment_status) = comment_status(review, comment_id) else {
258        warn!(
259            review = %context.input.review_name,
260            provider = %context.input.provider.as_str(),
261            comment_id,
262            "ai session target comment not found"
263        );
264        result.push_failed(comment_id, "comment not found in review");
265        emit_progress(
266            context.progress_sender.as_ref(),
267            context.input.provider,
268            "system",
269            format!("thread #{comment_id}: failed (comment not found)"),
270        );
271        return Ok(());
272    };
273
274    if !comment_is_targetable(comment_status.clone()) {
275        debug!(
276            review = %context.input.review_name,
277            provider = %context.input.provider.as_str(),
278            comment_id,
279            status = ?comment_status,
280            "skipping non-targetable comment for selected mode"
281        );
282        result.push_skipped(
283            comment_id,
284            format!(
285                "comment status {:?} is not targetable for {} mode",
286                comment_status,
287                context.input.mode.as_str()
288            ),
289        );
290        emit_progress(
291            context.progress_sender.as_ref(),
292            context.input.provider,
293            "system",
294            format!("thread #{comment_id}: skipped (status={comment_status:?})"),
295        );
296        return Ok(());
297    }
298
299    let prompt = build_thread_prompt(
300        &context.input.review_name,
301        comment_id,
302        review,
303        context.diff_document,
304        context.input.mode,
305        context.task_prompt_override,
306    )
307    .await?;
308    let provider_reply = match invoke_provider(
309        context.config,
310        context.input.provider,
311        context.input.transport,
312        context.input.mode,
313        &prompt,
314        context.progress_sender.clone(),
315    )
316    .await
317    {
318        Ok(reply) => reply,
319        Err(error) => {
320            error!(
321                review = %context.input.review_name,
322                provider = %context.input.provider.as_str(),
323                comment_id,
324                error = %error,
325                "provider invocation failed"
326            );
327            result.push_failed(comment_id, format!("provider failed: {error}"));
328            emit_progress(
329                context.progress_sender.as_ref(),
330                context.input.provider,
331                "system",
332                format!("thread #{comment_id}: failed ({error})"),
333            );
334            return Ok(());
335        }
336    };
337    let parsed_reply = match parse_ai_thread_reply_json(&provider_reply.reply, comment_id) {
338        Ok(parsed_reply) => parsed_reply,
339        Err(error) => {
340            result.push_failed(comment_id, format!("invalid AI reply JSON: {error}"));
341            emit_progress(
342                context.progress_sender.as_ref(),
343                context.input.provider,
344                "system",
345                format!("thread #{comment_id}: failed (invalid AI reply JSON: {error})"),
346            );
347            return Ok(());
348        }
349    };
350    let reply_body = format_ai_reply_body(provider_reply.model.as_deref(), &parsed_reply.reply);
351
352    *review = match context
353        .service
354        .add_reply(
355            &context.input.review_name,
356            AddReplyInput {
357                comment_id: parsed_reply.thread_id,
358                author: Author::Ai,
359                body: reply_body,
360            },
361        )
362        .await
363    {
364        Ok(updated) => updated,
365        Err(error) => {
366            error!(
367                review = %context.input.review_name,
368                provider = %context.input.provider.as_str(),
369                comment_id,
370                error = %error,
371                "failed to persist ai reply"
372            );
373            result.push_failed(comment_id, format!("failed to persist ai reply: {error}"));
374            emit_progress(
375                context.progress_sender.as_ref(),
376                context.input.provider,
377                "system",
378                format!("thread #{comment_id}: failed (persist reply: {error})"),
379            );
380            return Ok(());
381        }
382    };
383
384    info!(
385        review = %context.input.review_name,
386        provider = %context.input.provider.as_str(),
387        comment_id,
388        "ai reply persisted"
389    );
390    result.push_processed(comment_id, processed_target_message(context.input.mode));
391    emit_progress(
392        context.progress_sender.as_ref(),
393        context.input.provider,
394        "system",
395        format!(
396            "thread #{comment_id}: reply persisted; status pending_human ({step_number}/{total_targets})"
397        ),
398    );
399    Ok(())
400}
401
402fn ai_session_target_ids(review: &ReviewSession, comment_ids: &[u64]) -> Vec<u64> {
403    if !comment_ids.is_empty() {
404        return comment_ids.to_vec();
405    }
406
407    review
408        .comments
409        .iter()
410        .filter(|comment| comment_is_targetable(comment.status.clone()))
411        .map(|comment| comment.id)
412        .collect()
413}
414
415fn comment_status(review: &ReviewSession, comment_id: u64) -> Option<CommentStatus> {
416    review
417        .comments
418        .iter()
419        .find(|comment| comment.id == comment_id)
420        .map(|comment| comment.status.clone())
421}
422
423fn no_targets_message(mode: AiSessionMode) -> &'static str {
424    match mode {
425        AiSessionMode::Reply => "no replyable threads to process",
426        AiSessionMode::Refactor => "no open threads to process",
427    }
428}
429
430fn processed_target_message(mode: AiSessionMode) -> &'static str {
431    match mode {
432        AiSessionMode::Reply => "ai reply added",
433        AiSessionMode::Refactor => "ai reply added; thread status moved to pending_human",
434    }
435}
436
/// A validated AI reply for one thread, produced by `parse_ai_thread_reply_json`.
#[derive(Debug)]
struct ParsedAiThreadReply {
    /// Thread id, already checked to equal the requested thread.
    thread_id: u64,
    /// Reply text, trimmed and guaranteed non-empty.
    reply: String,
}
442
/// Raw JSON object the provider must return for a thread; any extra field is rejected.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct AiThreadReplyJson {
    /// Must equal the requested thread id.
    thread_id: u64,
    /// Reply text; trimmed and required to be non-empty during validation.
    reply: String,
    /// Must be exactly `"pending_human"`.
    status: String,
}
450
451fn parse_ai_thread_reply_json(
452    raw_reply: &str,
453    expected_thread_id: u64,
454) -> Result<ParsedAiThreadReply> {
455    let json = strip_json_code_fence(raw_reply).trim();
456    let parsed: AiThreadReplyJson = match serde_json::from_str(json) {
457        Ok(parsed) => parsed,
458        Err(error) => {
459            let Some(candidate) = embedded_ai_reply_json_candidate(json) else {
460                return Err(invalid_ai_reply_json_error(error, json));
461            };
462            serde_json::from_str(candidate)
463                .map_err(|error| invalid_ai_reply_json_error(error, candidate))?
464        }
465    };
466
467    if parsed.thread_id != expected_thread_id {
468        return Err(anyhow!(
469            "thread_id {} did not match requested thread {}",
470            parsed.thread_id,
471            expected_thread_id
472        ));
473    }
474
475    if parsed.status != "pending_human" {
476        return Err(anyhow!(
477            "status {:?} did not match required pending_human",
478            parsed.status
479        ));
480    }
481
482    let reply = parsed.reply.trim().to_string();
483    if reply.is_empty() {
484        return Err(anyhow!("reply must not be empty"));
485    }
486
487    Ok(ParsedAiThreadReply {
488        thread_id: parsed.thread_id,
489        reply,
490    })
491}
492
493fn invalid_ai_reply_json_error(error: serde_json::Error, json: &str) -> anyhow::Error {
494    let trimmed = json.trim();
495    if trimmed.is_empty() {
496        return anyhow!(
497            "expected JSON object with thread_id, reply, status: {error}; response was empty"
498        );
499    }
500
501    anyhow!(
502        "expected JSON object with thread_id, reply, status: {error}; response preview: {}",
503        ai_reply_preview(trimmed)
504    )
505}
506
/// Returns at most the first 500 characters of `value` on one line.
///
/// Carriage returns, newlines and tabs become the two-character escapes `\r`, `\n`, `\t`.
/// `...` is appended when the input was longer than the limit.
fn ai_reply_preview(value: &str) -> String {
    const MAX_PREVIEW_CHARS: usize = 500;
    let mut remaining = value.chars();
    let mut preview = String::new();
    for ch in remaining.by_ref().take(MAX_PREVIEW_CHARS) {
        match ch {
            '\r' => preview.push_str("\\r"),
            '\n' => preview.push_str("\\n"),
            '\t' => preview.push_str("\\t"),
            other => preview.push(other),
        }
    }
    // Any character left after the limit means the preview was truncated.
    if remaining.next().is_some() {
        preview.push_str("...");
    }
    preview
}
521
/// Removes a surrounding Markdown code fence (optionally tagged `json`) from a reply.
///
/// Unfenced input is returned trimmed. A missing closing fence is tolerated: the text after
/// the opening fence is returned with leading whitespace removed.
fn strip_json_code_fence(raw_reply: &str) -> &str {
    let trimmed = raw_reply.trim();
    let Some(after_fence) = trimmed.strip_prefix("```") else {
        return trimmed;
    };
    let body = after_fence
        .strip_prefix("json")
        .unwrap_or(after_fence)
        .trim_start();
    match body.strip_suffix("```") {
        Some(inner) => inner.trim(),
        None => body,
    }
}
543
544fn embedded_ai_reply_json_candidate(value: &str) -> Option<&str> {
545    let mut search_start = 0;
546    while search_start < value.len() {
547        let start = value.get(search_start..)?.find('{')? + search_start;
548        let end = balanced_json_object_end(value, start)?;
549        let candidate = &value[start..end];
550        if has_ai_reply_schema_keys(candidate) {
551            return Some(candidate);
552        }
553        search_start = end;
554    }
555    None
556}
557
558fn has_ai_reply_schema_keys(candidate: &str) -> bool {
559    let Ok(value) = serde_json::from_str::<serde_json::Value>(candidate) else {
560        return false;
561    };
562    let Some(object) = value.as_object() else {
563        return false;
564    };
565    object.contains_key("thread_id")
566        && object.contains_key("reply")
567        && object.contains_key("status")
568}
569
/// Returns the byte index just past the `}` that closes the object opening at `start`.
///
/// Braces inside JSON strings (including after escaped quotes) are ignored. Returns `None`
/// when `start` is out of range or not a char boundary, or when the object never closes.
fn balanced_json_object_end(value: &str, start: usize) -> Option<usize> {
    let tail = value.get(start..)?;
    let mut open_braces = 0usize;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (offset, ch) in tail.char_indices() {
        if !inside_string {
            match ch {
                '"' => inside_string = true,
                '{' => open_braces = open_braces.saturating_add(1),
                '}' => {
                    open_braces = open_braces.checked_sub(1)?;
                    if open_braces == 0 {
                        return Some(start + offset + ch.len_utf8());
                    }
                }
                _ => {}
            }
        } else if after_backslash {
            // The escaped character (possibly a quote) never ends the string.
            after_backslash = false;
        } else if ch == '\\' {
            after_backslash = true;
        } else if ch == '"' {
            inside_string = false;
        }
    }

    None
}
602
603fn comment_is_targetable(status: CommentStatus) -> bool {
604    matches!(status, CommentStatus::Open | CommentStatus::Pending)
605}