1use self::progress::emit_progress;
2use self::prompt::{build_thread_prompt, load_task_prompt_override};
3use self::provider::{format_ai_reply_body, invoke_provider};
4use crate::domain::ai::{AiProvider, AiSessionMode};
5use crate::domain::config::{AgentTransport, AiProviderConfig, AppConfig};
6use crate::domain::diff::DiffDocument;
7use crate::domain::review::{Author, CommentStatus, ReviewSession};
8use crate::git::diff::{DiffSource, load_git_diff};
9use crate::services::review_service::{AddReplyInput, ReviewService};
10use crate::utils::time::now_ms;
11use anyhow::{Result, anyhow};
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
16pub(crate) mod json_text;
17mod progress;
18mod prompt;
19mod provider;
20
21#[cfg(test)]
22mod tests;
23
/// Parameters for one AI session run over a review.
#[derive(Debug, Clone)]
pub struct RunAiSessionInput {
    /// Name of the review loaded through `ReviewService::load_review`.
    pub review_name: String,
    /// AI provider invoked for each targeted thread.
    pub provider: AiProvider,
    /// Optional transport override. When resolving the provider config for the
    /// session result, `None` falls back to `config.ai.default_transport`.
    pub transport: Option<AgentTransport>,
    /// Explicit thread ids to process. When empty, every open or pending
    /// comment in the review is targeted.
    pub comment_ids: Vec<u64>,
    /// Session mode (`Reply` or `Refactor`). It selects the task prompt
    /// override and the result/progress messages.
    pub mode: AiSessionMode,
    /// Diff loaded as optional prompt context. Load failures are logged and ignored.
    pub diff_source: DiffSource,
}
33
/// Summary of one AI session run, serialized for callers.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiSessionResult {
    /// Review the session ran against.
    pub review_name: String,
    /// Provider name (`AiProvider::as_str`).
    pub provider: String,
    /// Session mode name (`AiSessionMode::as_str`).
    pub mode: String,
    /// Transport of the resolved provider config.
    pub transport: String,
    /// Client of the resolved provider config.
    pub client: String,
    /// Model of the resolved provider config, if one is configured.
    pub model: Option<String>,
    /// Identifier of the form `<review_name>-<provider>-<now_ms>`.
    pub session_id: String,
    /// Number of threads whose AI reply was persisted.
    pub processed: usize,
    /// Number of skipped entries: non-targetable threads, or the single
    /// synthetic entry recorded when there were no targets.
    pub skipped: usize,
    /// Number of threads whose processing failed.
    pub failed: usize,
    /// Per-thread outcomes in processing order.
    pub items: Vec<AiSessionItemResult>,
}
49
/// Outcome for one targeted thread within an [`AiSessionResult`].
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiSessionItemResult {
    /// Thread id. It is `0` for the synthetic entry recorded when there were no targets.
    pub comment_id: u64,
    /// One of `"processed"`, `"skipped"` or `"failed"`.
    pub status: String,
    /// Human-readable detail for the outcome.
    pub message: String,
}
57
58impl AiSessionResult {
59 fn new(input: &RunAiSessionInput, provider_cfg: &AiProviderConfig, now_ms: u64) -> Self {
60 Self {
61 review_name: input.review_name.clone(),
62 provider: input.provider.as_str().to_string(),
63 mode: input.mode.as_str().to_string(),
64 transport: provider_cfg.transport.as_str().to_string(),
65 client: provider_cfg.client.clone(),
66 model: provider_cfg.model.clone(),
67 session_id: format!("{}-{}-{now_ms}", input.review_name, input.provider.as_str()),
68 processed: 0,
69 skipped: 0,
70 failed: 0,
71 items: Vec::new(),
72 }
73 }
74
75 fn push_processed(&mut self, comment_id: u64, message: impl Into<String>) {
76 self.processed += 1;
77 self.push_item(comment_id, "processed", message);
78 }
79
80 fn push_skipped(&mut self, comment_id: u64, message: impl Into<String>) {
81 self.skipped += 1;
82 self.push_item(comment_id, "skipped", message);
83 }
84
85 fn push_failed(&mut self, comment_id: u64, message: impl Into<String>) {
86 self.failed += 1;
87 self.push_item(comment_id, "failed", message);
88 }
89
90 fn push_item(&mut self, comment_id: u64, status: &str, message: impl Into<String>) {
91 self.items.push(AiSessionItemResult {
92 comment_id,
93 status: status.to_string(),
94 message: message.into(),
95 });
96 }
97}
98
/// A progress line streamed to callers of [`run_ai_session_with_progress`].
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct AiProgressEvent {
    /// Event time in milliseconds. Presumably epoch ms set by
    /// `progress::emit_progress` — TODO confirm.
    pub timestamp_ms: u64,
    /// Provider the event belongs to. Presumably built from the `AiProvider`
    /// passed to `emit_progress`.
    pub provider: String,
    /// Origin of the message. This module emits `"system"` events; other
    /// values presumably come from provider output — TODO confirm.
    pub stream: String,
    /// Human-readable progress text.
    pub message: String,
}
107
108#[must_use]
109pub fn default_ai_session_mode(comment_ids: &[u64]) -> AiSessionMode {
110 if comment_ids.is_empty() {
111 AiSessionMode::Refactor
112 } else {
113 AiSessionMode::Reply
114 }
115}
116
117pub async fn run_ai_session(
122 service: &ReviewService,
123 input: RunAiSessionInput,
124) -> Result<AiSessionResult> {
125 run_ai_session_inner(service, input, None).await
126}
127
128pub async fn run_ai_session_with_progress(
133 service: &ReviewService,
134 input: RunAiSessionInput,
135 progress_sender: mpsc::UnboundedSender<AiProgressEvent>,
136) -> Result<AiSessionResult> {
137 run_ai_session_inner(service, input, Some(progress_sender)).await
138}
139
140async fn run_ai_session_inner(
141 service: &ReviewService,
142 input: RunAiSessionInput,
143 progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
144) -> Result<AiSessionResult> {
145 info!(
146 review = %input.review_name,
147 provider = %input.provider.as_str(),
148 requested_comments = input.comment_ids.len(),
149 "starting ai session"
150 );
151 let config = service.load_config().await?;
152 let mut review = service.load_review(&input.review_name).await?;
153 let diff_document = match load_git_diff(&config, &input.diff_source).await {
154 Ok(document) => Some(document),
155 Err(error) => {
156 warn!(error = %error, "ai session prompt context: unable to load git diff");
157 None
158 }
159 };
160 let now_ms = now_ms()?;
161 let effective_transport = input.transport.or(config.ai.default_transport);
162 let provider_cfg = config
163 .ai
164 .provider_config_for_transport(input.provider, effective_transport);
165 let mut result = AiSessionResult::new(&input, &provider_cfg, now_ms);
166
167 let target_ids = ai_session_target_ids(&review, &input.comment_ids);
168 let total_targets = target_ids.len();
169 if total_targets == 0 {
170 result.push_skipped(0, no_targets_message(input.mode));
171 emit_progress(
172 progress_sender.as_ref(),
173 input.provider,
174 "system",
175 "no open threads to process",
176 );
177 return Ok(result);
178 }
179
180 let task_prompt_override = load_task_prompt_override(&config, input.mode).await?;
181 let context = AiSessionExecutionContext {
182 service,
183 config: &config,
184 input: &input,
185 diff_document: diff_document.as_ref(),
186 task_prompt_override: task_prompt_override.as_deref(),
187 progress_sender,
188 };
189 process_ai_session_targets(&context, &mut review, &mut result, target_ids).await?;
190
191 info!(
192 review = %input.review_name,
193 provider = %input.provider.as_str(),
194 processed = result.processed,
195 skipped = result.skipped,
196 failed = result.failed,
197 "ai session completed"
198 );
199 Ok(result)
200}
201
/// State shared by every target processed within one AI session.
struct AiSessionExecutionContext<'a> {
    /// Service used to persist AI replies (`add_reply`).
    service: &'a ReviewService,
    /// Loaded application config, passed to `invoke_provider`.
    config: &'a AppConfig,
    /// The caller's session request.
    input: &'a RunAiSessionInput,
    /// Git diff used as prompt context. It is `None` when loading the diff failed.
    diff_document: Option<&'a DiffDocument>,
    /// Task prompt override loaded for the session mode, if any.
    task_prompt_override: Option<&'a str>,
    /// Optional channel for streaming progress events to the caller.
    progress_sender: Option<mpsc::UnboundedSender<AiProgressEvent>>,
}
210
211async fn process_ai_session_targets(
212 context: &AiSessionExecutionContext<'_>,
213 review: &mut ReviewSession,
214 result: &mut AiSessionResult,
215 target_ids: Vec<u64>,
216) -> Result<()> {
217 let total_targets = target_ids.len();
218 for (step_index, comment_id) in target_ids.into_iter().enumerate() {
219 let step_number = step_index + 1;
220 emit_progress(
221 context.progress_sender.as_ref(),
222 context.input.provider,
223 "system",
224 format!(
225 "thread #{comment_id}: start ({}/{})",
226 step_number, total_targets
227 ),
228 );
229 debug!(
230 review = %context.input.review_name,
231 provider = %context.input.provider.as_str(),
232 comment_id,
233 "processing ai thread"
234 );
235 process_ai_session_target(
236 context,
237 review,
238 result,
239 comment_id,
240 step_number,
241 total_targets,
242 )
243 .await?;
244 }
245
246 Ok(())
247}
248
249async fn process_ai_session_target(
250 context: &AiSessionExecutionContext<'_>,
251 review: &mut ReviewSession,
252 result: &mut AiSessionResult,
253 comment_id: u64,
254 step_number: usize,
255 total_targets: usize,
256) -> Result<()> {
257 let Some(comment_status) = comment_status(review, comment_id) else {
258 warn!(
259 review = %context.input.review_name,
260 provider = %context.input.provider.as_str(),
261 comment_id,
262 "ai session target comment not found"
263 );
264 result.push_failed(comment_id, "comment not found in review");
265 emit_progress(
266 context.progress_sender.as_ref(),
267 context.input.provider,
268 "system",
269 format!("thread #{comment_id}: failed (comment not found)"),
270 );
271 return Ok(());
272 };
273
274 if !comment_is_targetable(comment_status.clone()) {
275 debug!(
276 review = %context.input.review_name,
277 provider = %context.input.provider.as_str(),
278 comment_id,
279 status = ?comment_status,
280 "skipping non-targetable comment for selected mode"
281 );
282 result.push_skipped(
283 comment_id,
284 format!(
285 "comment status {:?} is not targetable for {} mode",
286 comment_status,
287 context.input.mode.as_str()
288 ),
289 );
290 emit_progress(
291 context.progress_sender.as_ref(),
292 context.input.provider,
293 "system",
294 format!("thread #{comment_id}: skipped (status={comment_status:?})"),
295 );
296 return Ok(());
297 }
298
299 let prompt = build_thread_prompt(
300 &context.input.review_name,
301 comment_id,
302 review,
303 context.diff_document,
304 context.input.mode,
305 context.task_prompt_override,
306 )
307 .await?;
308 let provider_reply = match invoke_provider(
309 context.config,
310 context.input.provider,
311 context.input.transport,
312 context.input.mode,
313 &prompt,
314 context.progress_sender.clone(),
315 )
316 .await
317 {
318 Ok(reply) => reply,
319 Err(error) => {
320 error!(
321 review = %context.input.review_name,
322 provider = %context.input.provider.as_str(),
323 comment_id,
324 error = %error,
325 "provider invocation failed"
326 );
327 result.push_failed(comment_id, format!("provider failed: {error}"));
328 emit_progress(
329 context.progress_sender.as_ref(),
330 context.input.provider,
331 "system",
332 format!("thread #{comment_id}: failed ({error})"),
333 );
334 return Ok(());
335 }
336 };
337 let parsed_reply = match parse_ai_thread_reply_json(&provider_reply.reply, comment_id) {
338 Ok(parsed_reply) => parsed_reply,
339 Err(error) => {
340 result.push_failed(comment_id, format!("invalid AI reply JSON: {error}"));
341 emit_progress(
342 context.progress_sender.as_ref(),
343 context.input.provider,
344 "system",
345 format!("thread #{comment_id}: failed (invalid AI reply JSON: {error})"),
346 );
347 return Ok(());
348 }
349 };
350 let reply_body = format_ai_reply_body(provider_reply.model.as_deref(), &parsed_reply.reply);
351
352 *review = match context
353 .service
354 .add_reply(
355 &context.input.review_name,
356 AddReplyInput {
357 comment_id: parsed_reply.thread_id,
358 author: Author::Ai,
359 body: reply_body,
360 },
361 )
362 .await
363 {
364 Ok(updated) => updated,
365 Err(error) => {
366 error!(
367 review = %context.input.review_name,
368 provider = %context.input.provider.as_str(),
369 comment_id,
370 error = %error,
371 "failed to persist ai reply"
372 );
373 result.push_failed(comment_id, format!("failed to persist ai reply: {error}"));
374 emit_progress(
375 context.progress_sender.as_ref(),
376 context.input.provider,
377 "system",
378 format!("thread #{comment_id}: failed (persist reply: {error})"),
379 );
380 return Ok(());
381 }
382 };
383
384 info!(
385 review = %context.input.review_name,
386 provider = %context.input.provider.as_str(),
387 comment_id,
388 "ai reply persisted"
389 );
390 result.push_processed(comment_id, processed_target_message(context.input.mode));
391 emit_progress(
392 context.progress_sender.as_ref(),
393 context.input.provider,
394 "system",
395 format!(
396 "thread #{comment_id}: reply persisted; status pending_human ({step_number}/{total_targets})"
397 ),
398 );
399 Ok(())
400}
401
402fn ai_session_target_ids(review: &ReviewSession, comment_ids: &[u64]) -> Vec<u64> {
403 if !comment_ids.is_empty() {
404 return comment_ids.to_vec();
405 }
406
407 review
408 .comments
409 .iter()
410 .filter(|comment| comment_is_targetable(comment.status.clone()))
411 .map(|comment| comment.id)
412 .collect()
413}
414
415fn comment_status(review: &ReviewSession, comment_id: u64) -> Option<CommentStatus> {
416 review
417 .comments
418 .iter()
419 .find(|comment| comment.id == comment_id)
420 .map(|comment| comment.status.clone())
421}
422
423fn no_targets_message(mode: AiSessionMode) -> &'static str {
424 match mode {
425 AiSessionMode::Reply => "no replyable threads to process",
426 AiSessionMode::Refactor => "no open threads to process",
427 }
428}
429
430fn processed_target_message(mode: AiSessionMode) -> &'static str {
431 match mode {
432 AiSessionMode::Reply => "ai reply added",
433 AiSessionMode::Refactor => "ai reply added; thread status moved to pending_human",
434 }
435}
436
/// A validated AI reply: `thread_id` matched the requested thread, the status
/// was `pending_human`, and `reply` is trimmed and non-empty.
#[derive(Debug)]
struct ParsedAiThreadReply {
    /// Thread the reply belongs to (equal to the requested thread id).
    thread_id: u64,
    /// Trimmed reply text.
    reply: String,
}
442
/// Raw JSON object the provider must return. Unknown fields are rejected.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct AiThreadReplyJson {
    /// Thread id the provider claims to answer.
    thread_id: u64,
    /// Reply text (validated as non-empty after trimming).
    reply: String,
    /// Requested status. It must be exactly `"pending_human"`.
    status: String,
}
450
451fn parse_ai_thread_reply_json(
452 raw_reply: &str,
453 expected_thread_id: u64,
454) -> Result<ParsedAiThreadReply> {
455 let json = strip_json_code_fence(raw_reply).trim();
456 let parsed: AiThreadReplyJson = match serde_json::from_str(json) {
457 Ok(parsed) => parsed,
458 Err(error) => {
459 let Some(candidate) = embedded_ai_reply_json_candidate(json) else {
460 return Err(invalid_ai_reply_json_error(error, json));
461 };
462 serde_json::from_str(candidate)
463 .map_err(|error| invalid_ai_reply_json_error(error, candidate))?
464 }
465 };
466
467 if parsed.thread_id != expected_thread_id {
468 return Err(anyhow!(
469 "thread_id {} did not match requested thread {}",
470 parsed.thread_id,
471 expected_thread_id
472 ));
473 }
474
475 if parsed.status != "pending_human" {
476 return Err(anyhow!(
477 "status {:?} did not match required pending_human",
478 parsed.status
479 ));
480 }
481
482 let reply = parsed.reply.trim().to_string();
483 if reply.is_empty() {
484 return Err(anyhow!("reply must not be empty"));
485 }
486
487 Ok(ParsedAiThreadReply {
488 thread_id: parsed.thread_id,
489 reply,
490 })
491}
492
493fn invalid_ai_reply_json_error(error: serde_json::Error, json: &str) -> anyhow::Error {
494 let trimmed = json.trim();
495 if trimmed.is_empty() {
496 return anyhow!(
497 "expected JSON object with thread_id, reply, status: {error}; response was empty"
498 );
499 }
500
501 anyhow!(
502 "expected JSON object with thread_id, reply, status: {error}; response preview: {}",
503 ai_reply_preview(trimmed)
504 )
505}
506
/// Returns at most the first 500 characters of `value`, with CR, LF and TAB
/// escaped as `\r`, `\n` and `\t`. It appends `...` when the input was
/// truncated.
fn ai_reply_preview(value: &str) -> String {
    const MAX_PREVIEW_CHARS: usize = 500;
    // Byte offset of the first character beyond the limit, if any.
    let (head, truncated) = match value.char_indices().nth(MAX_PREVIEW_CHARS) {
        Some((cut, _)) => (&value[..cut], true),
        None => (value, false),
    };

    let mut preview = String::with_capacity(head.len() + 3);
    for ch in head.chars() {
        match ch {
            '\r' => preview.push_str("\\r"),
            '\n' => preview.push_str("\\n"),
            '\t' => preview.push_str("\\t"),
            other => preview.push(other),
        }
    }
    if truncated {
        preview.push_str("...");
    }
    preview
}
521
/// Trims `raw_reply` and, if it starts with a Markdown code fence, removes
/// several parts of the fence:
/// - the opening ```` ``` ```` (plus a directly following `json` tag);
/// - whitespace after the opening fence;
/// - a trailing ```` ``` ````.
///
/// Replies that do not start with a fence are returned trimmed but otherwise
/// unchanged.
fn strip_json_code_fence(raw_reply: &str) -> &str {
    let trimmed = raw_reply.trim();
    let Some(after_fence) = trimmed.strip_prefix("```") else {
        return trimmed;
    };
    let body = after_fence
        .strip_prefix("json")
        .unwrap_or(after_fence)
        .trim_start();
    body.strip_suffix("```").map_or(body, str::trim)
}
543
544fn embedded_ai_reply_json_candidate(value: &str) -> Option<&str> {
545 let mut search_start = 0;
546 while search_start < value.len() {
547 let start = value.get(search_start..)?.find('{')? + search_start;
548 let end = balanced_json_object_end(value, start)?;
549 let candidate = &value[start..end];
550 if has_ai_reply_schema_keys(candidate) {
551 return Some(candidate);
552 }
553 search_start = end;
554 }
555 None
556}
557
558fn has_ai_reply_schema_keys(candidate: &str) -> bool {
559 let Ok(value) = serde_json::from_str::<serde_json::Value>(candidate) else {
560 return false;
561 };
562 let Some(object) = value.as_object() else {
563 return false;
564 };
565 object.contains_key("thread_id")
566 && object.contains_key("reply")
567 && object.contains_key("status")
568}
569
/// Scans `value` from byte offset `start`, which is expected to hold `{`, and
/// returns the exclusive byte offset just past the matching `}`.
///
/// Braces inside JSON string literals are ignored, and backslash escapes
/// inside strings are honored. Returns `None` in three cases:
/// - `start` is out of range or not a char boundary;
/// - the braces never balance;
/// - a `}` appears with no open brace.
fn balanced_json_object_end(value: &str, start: usize) -> Option<usize> {
    let tail = value.get(start..)?;
    let mut open_braces: usize = 0;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (offset, ch) in tail.char_indices() {
        match (inside_string, ch) {
            (true, _) if after_backslash => after_backslash = false,
            (true, '\\') => after_backslash = true,
            (true, '"') => inside_string = false,
            (true, _) => {}
            (false, '"') => inside_string = true,
            (false, '{') => open_braces = open_braces.saturating_add(1),
            (false, '}') => {
                open_braces = open_braces.checked_sub(1)?;
                if open_braces == 0 {
                    return Some(start + offset + ch.len_utf8());
                }
            }
            (false, _) => {}
        }
    }

    None
}
602
603fn comment_is_targetable(status: CommentStatus) -> bool {
604 matches!(status, CommentStatus::Open | CommentStatus::Pending)
605}