1use self::progress::emit_progress;
2use self::prompt::{build_thread_prompt, load_task_prompt_override};
3use self::provider::{format_ai_reply_body, invoke_provider};
4use crate::domain::ai::{AiProvider, AiSessionMode};
5use crate::domain::config::{AgentTransport, AiProviderConfig, AppConfig};
6use crate::domain::diff::DiffDocument;
7use crate::domain::review::{Author, CommentStatus, ReviewSession};
8use crate::git::diff::{DiffSource, load_git_diff};
9use crate::services::review_service::{AddReplyInput, ReviewService};
10use crate::utils::time::now_ms;
11use anyhow::{Result, anyhow};
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
16pub(crate) mod json_text;
17mod progress;
18mod prompt;
19mod provider;
20
21#[cfg(test)]
22mod tests;
23
24use std::path::PathBuf;
25
/// Parameters for one AI session run over a single review.
#[derive(Debug, Clone)]
pub struct RunAiSessionInput {
    /// Name of the review to load and reply on.
    pub review_name: String,
    /// AI provider invoked for each target thread.
    pub provider: AiProvider,
    /// Requested transport; when `None`, `config.ai.default_transport` is used to pick the
    /// provider configuration reported in the result.
    pub transport: Option<AgentTransport>,
    /// Comment (thread) ids to process, in order; when empty, every comment whose status is
    /// targetable is selected instead.
    pub comment_ids: Vec<u64>,
    /// Session mode; forwarded to prompt building and provider invocation and used to pick
    /// result messages.
    pub mode: AiSessionMode,
    /// Diff source used to load git diff context for prompts.
    pub diff_source: DiffSource,
    /// Working tree for the diff and the provider; the diff falls back to `.` when `None`.
    pub worktree_path: Option<PathBuf>,
}
36
/// Summary of one AI session run, serialized with snake_case keys.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiSessionResult {
    /// Review the session ran against.
    pub review_name: String,
    /// Provider name as returned by `AiProvider::as_str`.
    pub provider: String,
    /// Session mode as returned by `AiSessionMode::as_str`.
    pub mode: String,
    /// Transport of the resolved provider configuration.
    pub transport: String,
    /// Client of the resolved provider configuration.
    pub client: String,
    /// Model of the resolved provider configuration, if one is configured.
    pub model: Option<String>,
    /// `<review_name>-<provider>-<now_ms>` identifier built during session setup.
    pub session_id: String,
    /// Number of items recorded with status `processed`.
    pub processed: usize,
    /// Number of items recorded with status `skipped`.
    pub skipped: usize,
    /// Number of items recorded with status `failed`.
    pub failed: usize,
    /// Per-thread outcomes in the order they were recorded; when no threads were
    /// targeted this holds a single `skipped` item with `comment_id` 0.
    pub items: Vec<AiSessionItemResult>,
}
52
/// Outcome for a single targeted thread within an `AiSessionResult`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiSessionItemResult {
    /// Thread (comment) id; 0 for the synthetic "no targets" item.
    pub comment_id: u64,
    /// `"processed"`, `"skipped"`, or `"failed"` when recorded by this module.
    pub status: String,
    /// Human-readable detail, such as the skip or failure reason.
    pub message: String,
}
60
61impl AiSessionResult {
62 fn new(input: &RunAiSessionInput, provider_cfg: &AiProviderConfig, now_ms: u64) -> Self {
63 Self {
64 review_name: input.review_name.clone(),
65 provider: input.provider.as_str().to_string(),
66 mode: input.mode.as_str().to_string(),
67 transport: provider_cfg.transport.as_str().to_string(),
68 client: provider_cfg.client.clone(),
69 model: provider_cfg.model.clone(),
70 session_id: format!("{}-{}-{now_ms}", input.review_name, input.provider.as_str()),
71 processed: 0,
72 skipped: 0,
73 failed: 0,
74 items: Vec::new(),
75 }
76 }
77
78 fn push_processed(&mut self, comment_id: u64, message: impl Into<String>) {
79 self.processed += 1;
80 self.push_item(comment_id, "processed", message);
81 }
82
83 fn push_skipped(&mut self, comment_id: u64, message: impl Into<String>) {
84 self.skipped += 1;
85 self.push_item(comment_id, "skipped", message);
86 }
87
88 fn push_failed(&mut self, comment_id: u64, message: impl Into<String>) {
89 self.failed += 1;
90 self.push_item(comment_id, "failed", message);
91 }
92
93 fn push_item(&mut self, comment_id: u64, status: &str, message: impl Into<String>) {
94 self.items.push(AiSessionItemResult {
95 comment_id,
96 status: status.to_string(),
97 message: message.into(),
98 });
99 }
100}
101
/// Progress notification delivered over the channel given to `run_ai_session_with_progress`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiProgressEvent {
    /// Event time in milliseconds (presumably since the Unix epoch — confirm in `progress`).
    pub timestamp_ms: u64,
    /// Provider the event relates to.
    pub provider: String,
    /// Stream label; this module always emits `"system"`, other labels presumably come
    /// from provider output — confirm in `provider`/`progress`.
    pub stream: String,
    /// Human-readable progress message.
    pub message: String,
}
110
111#[must_use]
112pub fn default_ai_session_mode(comment_ids: &[u64]) -> AiSessionMode {
113 if comment_ids.is_empty() {
114 AiSessionMode::Refactor
115 } else {
116 AiSessionMode::Reply
117 }
118}
119
120pub async fn run_ai_session(
125 service: &ReviewService,
126 input: RunAiSessionInput,
127) -> Result<AiSessionResult> {
128 run_ai_session_inner(service, input, None).await
129}
130
131pub async fn run_ai_session_with_progress(
136 service: &ReviewService,
137 input: RunAiSessionInput,
138 progress_sender: mpsc::UnboundedSender<AiProgressEvent>,
139) -> Result<AiSessionResult> {
140 run_ai_session_inner(service, input, Some(progress_sender)).await
141}
142
143async fn run_ai_session_inner(
144 service: &ReviewService,
145 input: RunAiSessionInput,
146 progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
147) -> Result<AiSessionResult> {
148 info!(
149 review = %input.review_name,
150 provider = %input.provider.as_str(),
151 requested_comments = input.comment_ids.len(),
152 "starting ai session"
153 );
154 let config = service.load_config().await?;
155 let mut review = service.load_review(&input.review_name).await?;
156 let worktree = input
157 .worktree_path
158 .as_deref()
159 .unwrap_or_else(|| std::path::Path::new("."));
160 let diff_document = match load_git_diff(&config, &input.diff_source, worktree).await {
161 Ok(document) => Some(document),
162 Err(error) => {
163 warn!(error = %error, "ai session prompt context: unable to load git diff");
164 None
165 }
166 };
167 let now_ms = now_ms()?;
168 let effective_transport = input.transport.or(config.ai.default_transport);
169 let provider_cfg = config
170 .ai
171 .provider_config_for_transport(input.provider, effective_transport);
172 let mut result = AiSessionResult::new(&input, &provider_cfg, now_ms);
173
174 let target_ids = ai_session_target_ids(&review, &input.comment_ids);
175 let total_targets = target_ids.len();
176 if total_targets == 0 {
177 result.push_skipped(0, no_targets_message(input.mode));
178 emit_progress(
179 progress_sender.as_ref(),
180 input.provider,
181 "system",
182 "no open threads to process",
183 );
184 return Ok(result);
185 }
186
187 let task_prompt_override = load_task_prompt_override(&config, input.mode).await?;
188 let context = AiSessionExecutionContext {
189 service,
190 config: &config,
191 input: &input,
192 diff_document: diff_document.as_ref(),
193 task_prompt_override: task_prompt_override.as_deref(),
194 progress_sender,
195 };
196 process_ai_session_targets(&context, &mut review, &mut result, target_ids).await?;
197
198 info!(
199 review = %input.review_name,
200 provider = %input.provider.as_str(),
201 processed = result.processed,
202 skipped = result.skipped,
203 failed = result.failed,
204 "ai session completed"
205 );
206 Ok(result)
207}
208
/// Per-session state shared by every target processed in one AI session run.
struct AiSessionExecutionContext<'a> {
    service: &'a ReviewService,
    config: &'a AppConfig,
    input: &'a RunAiSessionInput,
    /// Git diff used as prompt context; `None` when the diff could not be loaded.
    diff_document: Option<&'a DiffDocument>,
    /// Task prompt override loaded for the session mode, if any.
    task_prompt_override: Option<&'a str>,
    /// Optional progress channel; cloned for each provider invocation.
    progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
}
217
218async fn process_ai_session_targets(
219 context: &AiSessionExecutionContext<'_>,
220 review: &mut ReviewSession,
221 result: &mut AiSessionResult,
222 target_ids: Vec<u64>,
223) -> Result<()> {
224 let total_targets = target_ids.len();
225 for (step_index, comment_id) in target_ids.into_iter().enumerate() {
226 let step_number = step_index + 1;
227 emit_progress(
228 context.progress_sender.as_ref(),
229 context.input.provider,
230 "system",
231 format!("thread #{comment_id}: start ({step_number}/{total_targets})"),
232 );
233 debug!(
234 review = %context.input.review_name,
235 provider = %context.input.provider.as_str(),
236 comment_id,
237 "processing ai thread"
238 );
239 process_ai_session_target(
240 context,
241 review,
242 result,
243 comment_id,
244 step_number,
245 total_targets,
246 )
247 .await?;
248 }
249
250 Ok(())
251}
252
253async fn process_ai_session_target(
254 context: &AiSessionExecutionContext<'_>,
255 review: &mut ReviewSession,
256 result: &mut AiSessionResult,
257 comment_id: u64,
258 step_number: usize,
259 total_targets: usize,
260) -> Result<()> {
261 let opt_status = comment_status(review, comment_id);
262 if opt_status.is_none() {
263 warn!(
264 review = %context.input.review_name,
265 provider = %context.input.provider.as_str(),
266 comment_id,
267 "ai session target comment not found"
268 );
269 result.push_failed(comment_id, "comment not found in review");
270 emit_progress(
271 context.progress_sender.as_ref(),
272 context.input.provider,
273 "system",
274 format!("thread #{comment_id}: failed (comment not found)"),
275 );
276 return Ok(());
277 }
278 let comment_status = opt_status.unwrap();
279
280 if !comment_is_targetable(comment_status) {
281 debug!(
282 review = %context.input.review_name,
283 provider = %context.input.provider.as_str(),
284 comment_id,
285 status = ?comment_status,
286 "skipping non-targetable comment for selected mode"
287 );
288 result.push_skipped(
289 comment_id,
290 format!(
291 "comment status {:?} is not targetable for {} mode",
292 comment_status,
293 context.input.mode.as_str()
294 ),
295 );
296 emit_progress(
297 context.progress_sender.as_ref(),
298 context.input.provider,
299 "system",
300 format!("thread #{comment_id}: skipped (status={comment_status:?})"),
301 );
302 return Ok(());
303 }
304
305 let prompt = build_thread_prompt(
306 &context.input.review_name,
307 comment_id,
308 review,
309 context.diff_document,
310 context.input.mode,
311 context.task_prompt_override,
312 )
313 .await?;
314 let provider_reply = match invoke_provider(
315 context.config,
316 context.input.provider,
317 context.input.transport,
318 context.input.mode,
319 &prompt,
320 context.progress_sender.clone(),
321 context.input.worktree_path.as_deref(),
322 )
323 .await
324 {
325 Ok(reply) => reply,
326 Err(error) => {
327 error!(
328 review = %context.input.review_name,
329 provider = %context.input.provider.as_str(),
330 comment_id,
331 error = %error,
332 "provider invocation failed"
333 );
334 result.push_failed(comment_id, format!("provider failed: {error}"));
335 emit_progress(
336 context.progress_sender.as_ref(),
337 context.input.provider,
338 "system",
339 format!("thread #{comment_id}: failed ({error})"),
340 );
341 return Ok(());
342 }
343 };
344 let parsed_reply = match parse_ai_thread_reply_json(&provider_reply.reply, comment_id) {
345 Ok(parsed_reply) => parsed_reply,
346 Err(error) => {
347 result.push_failed(comment_id, format!("invalid AI reply JSON: {error}"));
348 emit_progress(
349 context.progress_sender.as_ref(),
350 context.input.provider,
351 "system",
352 format!("thread #{comment_id}: failed (invalid AI reply JSON: {error})"),
353 );
354 return Ok(());
355 }
356 };
357 let reply_body = format_ai_reply_body(provider_reply.model.as_deref(), &parsed_reply.reply);
358
359 *review = match context
360 .service
361 .add_reply(
362 &context.input.review_name,
363 AddReplyInput {
364 comment_id: parsed_reply.thread_id,
365 author: Author::Ai,
366 body: reply_body,
367 },
368 )
369 .await
370 {
371 Ok(updated) => updated,
372 Err(error) => {
373 error!(
374 review = %context.input.review_name,
375 provider = %context.input.provider.as_str(),
376 comment_id,
377 error = %error,
378 "failed to persist ai reply"
379 );
380 result.push_failed(comment_id, format!("failed to persist ai reply: {error}"));
381 emit_progress(
382 context.progress_sender.as_ref(),
383 context.input.provider,
384 "system",
385 format!("thread #{comment_id}: failed (persist reply: {error})"),
386 );
387 return Ok(());
388 }
389 };
390
391 info!(
392 review = %context.input.review_name,
393 provider = %context.input.provider.as_str(),
394 comment_id,
395 "ai reply persisted"
396 );
397 result.push_processed(comment_id, processed_target_message(context.input.mode));
398 emit_progress(
399 context.progress_sender.as_ref(),
400 context.input.provider,
401 "system",
402 format!(
403 "thread #{comment_id}: reply persisted; status pending_human ({step_number}/{total_targets})"
404 ),
405 );
406 Ok(())
407}
408
409fn ai_session_target_ids(review: &ReviewSession, comment_ids: &[u64]) -> Vec<u64> {
410 if !comment_ids.is_empty() {
411 return comment_ids.to_vec();
412 }
413
414 review
415 .comments
416 .iter()
417 .filter(|comment| comment_is_targetable(&comment.status))
418 .map(|comment| comment.id)
419 .collect()
420}
421
422fn comment_status(review: &ReviewSession, comment_id: u64) -> Option<&CommentStatus> {
423 review.comments.iter().find_map(|comment| {
424 if comment.id == comment_id {
425 Some(&comment.status)
426 } else {
427 None
428 }
429 })
430}
431
432fn no_targets_message(mode: AiSessionMode) -> &'static str {
433 match mode {
434 AiSessionMode::Reply => "no replyable threads to process",
435 AiSessionMode::Refactor => "no open threads to process",
436 }
437}
438
439fn processed_target_message(mode: AiSessionMode) -> &'static str {
440 match mode {
441 AiSessionMode::Reply => "ai reply added",
442 AiSessionMode::Refactor => "ai reply added; thread status moved to pending_human",
443 }
444}
445
/// Validated provider reply for one thread.
#[derive(Debug)]
struct ParsedAiThreadReply {
    /// Thread id; equal to the requested thread once validated.
    thread_id: u64,
    /// Reply text, trimmed and guaranteed non-empty.
    reply: String,
}
451
/// Raw JSON object expected from the provider; any extra field is rejected.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct AiThreadReplyJson {
    thread_id: u64,
    reply: String,
    // Must be "pending_human"; checked in `parse_ai_thread_reply_json`.
    status: String,
}
459
460fn parse_ai_thread_reply_json(
461 raw_reply: &str,
462 expected_thread_id: u64,
463) -> Result<ParsedAiThreadReply> {
464 let json = strip_json_code_fence(raw_reply).trim();
465 let parsed: AiThreadReplyJson = match serde_json::from_str(json) {
466 Ok(parsed) => parsed,
467 Err(error) => {
468 let Some(candidate) = embedded_ai_reply_json_candidate(json) else {
469 return Err(invalid_ai_reply_json_error(error, json));
470 };
471 serde_json::from_str(candidate)
472 .map_err(|error| invalid_ai_reply_json_error(error, candidate))?
473 }
474 };
475
476 if parsed.thread_id != expected_thread_id {
477 return Err(anyhow!(
478 "thread_id {} did not match requested thread {}",
479 parsed.thread_id,
480 expected_thread_id
481 ));
482 }
483
484 if parsed.status != "pending_human" {
485 return Err(anyhow!(
486 "status {:?} did not match required pending_human",
487 parsed.status
488 ));
489 }
490
491 let reply = parsed.reply.trim().to_string();
492 if reply.is_empty() {
493 return Err(anyhow!("reply must not be empty"));
494 }
495
496 Ok(ParsedAiThreadReply {
497 thread_id: parsed.thread_id,
498 reply,
499 })
500}
501
502fn invalid_ai_reply_json_error(error: serde_json::Error, json: &str) -> anyhow::Error {
503 let trimmed = json.trim();
504 if trimmed.is_empty() {
505 return anyhow!(
506 "expected JSON object with thread_id, reply, status: {error}; response was empty"
507 );
508 }
509
510 anyhow!(
511 "expected JSON object with thread_id, reply, status: {error}; response preview: {}",
512 ai_reply_preview(trimmed)
513 )
514}
515
/// Builds a single-line preview of `value` for error messages.
///
/// Keeps at most the first 500 characters and escapes `\r`, `\n` and `\t` as
/// two-character sequences. Appends `...` when the input was longer than the limit.
fn ai_reply_preview(value: &str) -> String {
    const MAX_PREVIEW_CHARS: usize = 500;

    // The byte offset of the first character past the limit, if there is one.
    let (head, truncated) = match value.char_indices().nth(MAX_PREVIEW_CHARS) {
        Some((cut, _)) => (&value[..cut], true),
        None => (value, false),
    };

    let mut preview = String::with_capacity(head.len() + 3);
    for ch in head.chars() {
        match ch {
            '\r' => preview.push_str("\\r"),
            '\n' => preview.push_str("\\n"),
            '\t' => preview.push_str("\\t"),
            other => preview.push(other),
        }
    }
    if truncated {
        preview.push_str("...");
    }
    preview
}
530
/// Removes a surrounding Markdown code fence (optionally tagged `json`) from a reply.
///
/// The input is trimmed first. Text that does not start with a fence is returned
/// trimmed. Otherwise the opening fence and an optional `json` tag are removed, along
/// with a closing fence at the very end if present.
fn strip_json_code_fence(raw_reply: &str) -> &str {
    let trimmed = raw_reply.trim();
    let Some(after_fence) = trimmed.strip_prefix("```") else {
        return trimmed;
    };

    let body = after_fence
        .strip_prefix("json")
        .unwrap_or(after_fence)
        .trim_start();

    match body.strip_suffix("```") {
        Some(inner) => inner.trim(),
        None => body,
    }
}
552
553fn embedded_ai_reply_json_candidate(value: &str) -> Option<&str> {
554 let mut search_start = 0;
555 while search_start < value.len() {
556 let start = value.get(search_start..)?.find('{')? + search_start;
557 let end = balanced_json_object_end(value, start)?;
558 let candidate = &value[start..end];
559 if has_ai_reply_schema_keys(candidate) {
560 return Some(candidate);
561 }
562 search_start = end;
563 }
564 None
565}
566
567fn has_ai_reply_schema_keys(candidate: &str) -> bool {
568 let Ok(value) = serde_json::from_str::<serde_json::Value>(candidate) else {
569 return false;
570 };
571 let Some(object) = value.as_object() else {
572 return false;
573 };
574 object.contains_key("thread_id")
575 && object.contains_key("reply")
576 && object.contains_key("status")
577}
578
/// Finds the exclusive byte end of the brace-delimited object that opens at `start`.
///
/// Tracks `{`/`}` nesting from `start`, ignoring braces inside double-quoted strings
/// (backslash escapes are honoured). Returns the index just past the `}` that brings
/// the nesting back to zero.
///
/// Returns `None` in any of these cases:
/// - `start` is out of range or not on a char boundary;
/// - a `}` appears before any `{`;
/// - the text ends before the object closes.
///
/// This is a lexical scan only; it does not validate JSON.
fn balanced_json_object_end(value: &str, start: usize) -> Option<usize> {
    let tail = value.get(start..)?;
    let mut open_braces: usize = 0;
    let mut inside_string = false;
    let mut escape_next = false;

    for (offset, ch) in tail.char_indices() {
        if inside_string {
            match ch {
                _ if escape_next => escape_next = false,
                '\\' => escape_next = true,
                '"' => inside_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => inside_string = true,
            '{' => open_braces = open_braces.saturating_add(1),
            '}' => {
                open_braces = open_braces.checked_sub(1)?;
                if open_braces == 0 {
                    return Some(start + offset + ch.len_utf8());
                }
            }
            _ => {}
        }
    }

    None
}
611
612fn comment_is_targetable(status: &CommentStatus) -> bool {
613 matches!(status, CommentStatus::Open | CommentStatus::Pending)
614}