1use std::cell::RefCell;
49use std::time::Duration;
50
51use crate::llm::api::{DeltaSender, LlmRequestPayload, LlmResult, ProviderTelemetry};
52use crate::llm::provider::{LlmProvider, LlmProviderChat};
53use crate::value::{ErrorCategory, VmError};
54
55#[derive(Clone, Debug)]
57pub enum FakeLlmTurn {
58 Stream(Vec<FakeLlmEvent>),
61 Error(FakeLlmError),
64 Stalled(Duration),
68}
69
70impl FakeLlmTurn {
71 pub fn stream(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
73 Self::Stream(events.into_iter().collect())
74 }
75
76 pub fn error(category: ErrorCategory, message: impl Into<String>) -> Self {
78 Self::Error(FakeLlmError {
79 category,
80 message: message.into(),
81 retry_after_ms: None,
82 })
83 }
84}
85
86#[derive(Clone, Debug)]
88pub enum FakeLlmEvent {
89 Token(String),
92 ToolCallDelta {
95 id: String,
96 name: String,
97 arguments: serde_json::Value,
98 },
99 Stall(Duration),
102 Done(FakeStopReason),
105}
106
107#[derive(Clone, Debug, PartialEq, Eq)]
109pub enum FakeStopReason {
110 EndTurn,
111 ToolUse,
112 MaxTokens,
113 StopSequence,
114 Custom(String),
115}
116
117impl FakeStopReason {
118 fn as_str(&self) -> &str {
119 match self {
120 Self::EndTurn => "end_turn",
121 Self::ToolUse => "tool_use",
122 Self::MaxTokens => "max_tokens",
123 Self::StopSequence => "stop_sequence",
124 Self::Custom(value) => value.as_str(),
125 }
126 }
127}
128
129#[derive(Clone, Debug)]
131pub struct FakeLlmError {
132 pub category: ErrorCategory,
133 pub message: String,
134 pub retry_after_ms: Option<u64>,
138}
139
140impl FakeLlmError {
141 pub fn new(category: ErrorCategory, message: impl Into<String>) -> Self {
142 Self {
143 category,
144 message: message.into(),
145 retry_after_ms: None,
146 }
147 }
148
149 pub fn with_retry_after_ms(mut self, ms: u64) -> Self {
150 self.retry_after_ms = Some(ms);
151 self
152 }
153}
154
155#[derive(Clone, Debug, Default)]
162pub struct FakeLlmScript {
163 pub turns: Vec<FakeLlmTurn>,
164}
165
166impl FakeLlmScript {
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn streaming(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
173 Self {
174 turns: vec![FakeLlmTurn::stream(events)],
175 }
176 }
177
178 pub fn erroring(category: ErrorCategory, message: impl Into<String>) -> Self {
180 Self {
181 turns: vec![FakeLlmTurn::error(category, message)],
182 }
183 }
184
185 pub fn push(mut self, turn: FakeLlmTurn) -> Self {
187 self.turns.push(turn);
188 self
189 }
190}
191
192#[derive(Clone, Debug)]
195pub struct FakeLlmCall {
196 pub provider: String,
197 pub model: String,
198 pub system: Option<String>,
199 pub messages: Vec<serde_json::Value>,
200 pub native_tools: Option<Vec<serde_json::Value>>,
201 pub stream: bool,
202}
203
204impl FakeLlmCall {
205 fn from_request(request: &LlmRequestPayload) -> Self {
206 Self {
207 provider: request.provider.clone(),
208 model: request.model.clone(),
209 system: request.system.clone(),
210 messages: request.messages.clone(),
211 native_tools: request.native_tools.clone(),
212 stream: request.stream,
213 }
214 }
215}
216
217thread_local! {
218 static FAKE_LLM_TURNS: RefCell<Vec<FakeLlmTurn>> = const { RefCell::new(Vec::new()) };
219 static FAKE_LLM_CALLS: RefCell<Vec<FakeLlmCall>> = const { RefCell::new(Vec::new()) };
220}
221
222#[must_use = "FakeLlmGuard asserts on drop; bind it to a `_guard` local"]
230pub fn install_fake_llm_script(script: FakeLlmScript) -> FakeLlmGuard {
231 FAKE_LLM_TURNS.with(|turns| {
232 let mut turns = turns.borrow_mut();
233 assert!(
234 turns.is_empty(),
235 "FakeLlmProvider: a script is already installed; drop the previous guard before installing a new one"
236 );
237 *turns = script.turns;
238 });
239 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
240 FakeLlmGuard { _priv: () }
241}
242
243pub fn fake_llm_captured_calls() -> Vec<FakeLlmCall> {
245 FAKE_LLM_CALLS.with(|calls| calls.borrow().clone())
246}
247
248#[must_use]
254pub struct FakeLlmGuard {
255 _priv: (),
256}
257
258impl Drop for FakeLlmGuard {
259 fn drop(&mut self) {
260 let remaining = FAKE_LLM_TURNS.with(|turns| std::mem::take(&mut *turns.borrow_mut()));
261 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
262 if std::thread::panicking() {
265 return;
266 }
267 assert!(
268 remaining.is_empty(),
269 "FakeLlmProvider script had {} unconsumed turn(s); did the code under test make fewer LLM calls than expected?",
270 remaining.len()
271 );
272 }
273}
274
275fn take_next_turn(request: &LlmRequestPayload) -> Result<FakeLlmTurn, VmError> {
278 FAKE_LLM_CALLS.with(|calls| {
279 calls.borrow_mut().push(FakeLlmCall::from_request(request));
280 });
281 FAKE_LLM_TURNS.with(|turns| {
282 let mut turns = turns.borrow_mut();
283 if turns.is_empty() {
284 Err(VmError::Runtime(
285 "FakeLlmProvider: no script installed (or script exhausted) — install_fake_llm_script() must precede llm_call(provider: \"fake\")".to_string()
286 ))
287 } else {
288 Ok(turns.remove(0))
289 }
290 })
291}
292
293fn fake_error_to_vm_error(err: &FakeLlmError) -> VmError {
296 let message = match err.retry_after_ms {
297 Some(ms) => {
298 let secs = (ms as f64 / 1000.0).max(0.0);
299 let sep = if err.message.is_empty() || err.message.ends_with('\n') {
300 ""
301 } else {
302 "\n"
303 };
304 format!("{}{sep}retry-after: {secs}\n", err.message)
305 }
306 None => err.message.clone(),
307 };
308 VmError::CategorizedError {
309 message,
310 category: err.category.clone(),
311 }
312}
313
314pub(crate) struct FakeLlmProvider;
316
317impl LlmProvider for FakeLlmProvider {
318 fn name(&self) -> &'static str {
319 "fake"
320 }
321
322 fn requires_model(&self) -> bool {
323 false
324 }
325}
326
327impl LlmProviderChat for FakeLlmProvider {
328 fn chat<'a>(
329 &'a self,
330 request: &'a LlmRequestPayload,
331 delta_tx: Option<DeltaSender>,
332 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<LlmResult, VmError>> + 'a>> {
333 Box::pin(self.chat_impl(request, delta_tx))
334 }
335}
336
337impl FakeLlmProvider {
338 pub(crate) fn should_intercept(provider: &str) -> bool {
340 provider == "fake"
341 }
342
343 pub(crate) async fn chat_impl(
344 &self,
345 request: &LlmRequestPayload,
346 delta_tx: Option<DeltaSender>,
347 ) -> Result<LlmResult, VmError> {
348 loop {
349 let turn = take_next_turn(request)?;
350 match turn {
351 FakeLlmTurn::Stalled(duration) => {
352 if !duration.is_zero() {
353 tokio::time::sleep(duration).await;
354 }
355 }
359 FakeLlmTurn::Error(err) => {
360 return Err(fake_error_to_vm_error(&err));
361 }
362 FakeLlmTurn::Stream(events) => {
363 return play_stream(request, events, delta_tx).await;
364 }
365 }
366 }
367 }
368}
369
370async fn play_stream(
371 request: &LlmRequestPayload,
372 events: Vec<FakeLlmEvent>,
373 delta_tx: Option<DeltaSender>,
374) -> Result<LlmResult, VmError> {
375 let mut text = String::new();
376 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
377 let mut blocks: Vec<serde_json::Value> = Vec::new();
378 let mut stop_reason: Option<FakeStopReason> = None;
379 let mut next_tool_index: usize = 1;
380 let mut schema_watch = crate::llm::api::StreamSchemaWatch::from_payload(request);
384
385 for event in events {
386 match event {
387 FakeLlmEvent::Token(chunk) => {
388 if let Some(tx) = delta_tx.as_ref() {
389 let _ = tx.send(chunk.clone());
390 }
391 text.push_str(&chunk);
392 if let Some(watch) = schema_watch.as_mut() {
393 if let Some(abort) = watch.observe(&chunk) {
394 return Err(abort.into_vm_error());
395 }
396 }
397 }
398 FakeLlmEvent::ToolCallDelta {
399 id,
400 name,
401 arguments,
402 } => {
403 let id = if id.is_empty() {
404 let auto = format!("fake_call_{next_tool_index}");
405 next_tool_index += 1;
406 auto
407 } else {
408 id
409 };
410 tool_calls.push(serde_json::json!({
411 "id": id,
412 "type": "tool_call",
413 "name": name,
414 "arguments": arguments,
415 }));
416 blocks.push(serde_json::json!({
417 "type": "tool_call",
418 "id": id,
419 "name": name,
420 "arguments": arguments,
421 "visibility": "internal",
422 }));
423 }
424 FakeLlmEvent::Stall(duration) => {
425 if !duration.is_zero() {
426 tokio::time::sleep(duration).await;
427 }
428 }
429 FakeLlmEvent::Done(reason) => {
430 stop_reason = Some(reason);
431 break;
432 }
433 }
434 }
435
436 if !text.is_empty() {
437 let text_block = serde_json::json!({
440 "type": "output_text",
441 "text": text,
442 "visibility": "public",
443 });
444 blocks.insert(0, text_block);
445 }
446
447 let stop_reason = stop_reason.unwrap_or(if tool_calls.is_empty() {
448 FakeStopReason::EndTurn
449 } else {
450 FakeStopReason::ToolUse
451 });
452
453 Ok(LlmResult {
454 served_fast: false,
455 text,
456 tool_calls,
457 input_tokens: count_input_tokens(&request.messages),
458 output_tokens: 0,
459 cache_read_tokens: 0,
460 cache_write_tokens: 0,
461 cache_supported: true,
462 model: request.model.clone(),
463 provider: "fake".to_string(),
464 thinking: None,
465 thinking_summary: None,
466 stop_reason: Some(stop_reason.as_str().to_string()),
467 blocks,
468 logprobs: Vec::new(),
469 telemetry: ProviderTelemetry::default(),
470 })
471}
472
473fn count_input_tokens(messages: &[serde_json::Value]) -> i64 {
477 fn collect(value: &serde_json::Value, out: &mut String) {
478 match value {
479 serde_json::Value::String(text) => {
480 out.push_str(text);
481 out.push('\n');
482 }
483 serde_json::Value::Array(items) => {
484 for item in items {
485 collect(item, out);
486 }
487 }
488 serde_json::Value::Object(map) => {
489 for value in map.values() {
490 collect(value, out);
491 }
492 }
493 _ => {}
494 }
495 }
496 let mut buf = String::new();
497 for message in messages {
498 collect(message, &mut buf);
499 }
500 buf.len() as i64
501}
502
503#[cfg(test)]
504mod tests {
505 use super::*;
506 use crate::llm::api::{LlmApiMode, ThinkingConfig};
507 use crate::llm::api::{LlmRequestPayload, LlmRouteFallback, OutputFormat};
508
509 fn fake_request() -> LlmRequestPayload {
510 LlmRequestPayload {
511 provider: "fake".to_string(),
512 model: "fake-model".to_string(),
513 region: None,
514 api_key: String::new(),
515 api_mode: LlmApiMode::ChatCompletions,
516 fallback_chain: Vec::new(),
517 route_fallbacks: Vec::<LlmRouteFallback>::new(),
518 messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
519 system: None,
520 max_tokens: 64,
521 temperature: None,
522 top_p: None,
523 top_k: None,
524 logprobs: false,
525 top_logprobs: None,
526 stop: None,
527 seed: None,
528 frequency_penalty: None,
529 presence_penalty: None,
530 fast: false,
531 output_format: OutputFormat::Text,
532 response_format: None,
533 json_schema: None,
534 output_schema: None,
535 schema_stream_abort: false,
536 thinking: ThinkingConfig::Disabled,
537 anthropic_beta_features: Vec::new(),
538 vision: false,
539 native_tools: None,
540 provider_tools: Vec::new(),
541 tool_choice: None,
542 cache: false,
543 timeout: None,
544 stream: true,
545 provider_overrides: None,
546 previous_response_id: None,
547 store: None,
548 background: None,
549 truncation: None,
550 compact: None,
551 include: None,
552 max_tool_calls: None,
553 prefill: None,
554 session_id: None,
555 reminder_lifecycle: Vec::new(),
556 }
557 }
558
559 fn current_thread_runtime() -> tokio::runtime::Runtime {
560 tokio::runtime::Builder::new_current_thread()
561 .enable_all()
562 .start_paused(false)
563 .build()
564 .expect("runtime")
565 }
566
567 #[test]
568 fn streaming_turn_emits_deltas_in_order() {
569 let runtime = current_thread_runtime();
570 let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
571 FakeLlmEvent::Token("hello ".into()),
572 FakeLlmEvent::Token("world".into()),
573 FakeLlmEvent::Done(FakeStopReason::EndTurn),
574 ]));
575
576 runtime.block_on(async {
577 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
578 let result = FakeLlmProvider
579 .chat_impl(&fake_request(), Some(tx))
580 .await
581 .expect("fake call should succeed");
582
583 let mut deltas = Vec::new();
584 while let Ok(delta) = rx.try_recv() {
585 deltas.push(delta);
586 }
587
588 assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
589 assert_eq!(result.text, "hello world");
590 assert_eq!(result.provider, "fake");
591 assert_eq!(result.stop_reason.as_deref(), Some("end_turn"));
592 assert_eq!(result.blocks.len(), 1);
593 assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
594 assert_eq!(result.blocks[0]["text"].as_str(), Some("hello world"));
595 });
596 assert_eq!(fake_llm_captured_calls().len(), 1);
597 }
598
599 #[test]
600 fn tool_call_deltas_become_tool_calls_and_blocks() {
601 let runtime = current_thread_runtime();
602 let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
603 FakeLlmEvent::Token("calling tool".into()),
604 FakeLlmEvent::ToolCallDelta {
605 id: String::new(),
606 name: "search".into(),
607 arguments: serde_json::json!({"q": "harn"}),
608 },
609 FakeLlmEvent::Done(FakeStopReason::ToolUse),
610 ]));
611
612 runtime.block_on(async {
613 let result = FakeLlmProvider
614 .chat_impl(&fake_request(), None)
615 .await
616 .expect("fake call should succeed");
617
618 assert_eq!(result.tool_calls.len(), 1);
619 assert_eq!(result.tool_calls[0]["name"].as_str(), Some("search"));
620 assert_eq!(result.tool_calls[0]["id"].as_str(), Some("fake_call_1"));
621 assert_eq!(
622 result.tool_calls[0]["arguments"]["q"].as_str(),
623 Some("harn")
624 );
625 assert_eq!(result.stop_reason.as_deref(), Some("tool_use"));
626 assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
628 assert_eq!(result.blocks[1]["type"].as_str(), Some("tool_call"));
629 assert_eq!(result.blocks[1]["name"].as_str(), Some("search"));
630 });
631 }
632
633 #[test]
634 fn error_turn_returns_categorized_error() {
635 let runtime = current_thread_runtime();
636 let _guard = install_fake_llm_script(FakeLlmScript::erroring(
637 ErrorCategory::RateLimit,
638 "throttled",
639 ));
640
641 runtime.block_on(async {
642 let err = FakeLlmProvider
643 .chat_impl(&fake_request(), None)
644 .await
645 .expect_err("fake error turn should fail");
646 match err {
647 VmError::CategorizedError { message, category } => {
648 assert_eq!(category, ErrorCategory::RateLimit);
649 assert!(
650 message.contains("throttled"),
651 "error message should pass through: {message}"
652 );
653 }
654 other => panic!("expected CategorizedError, got {other:?}"),
655 }
656 });
657 }
658
659 #[test]
660 fn error_turn_embeds_retry_after_hint() {
661 let runtime = current_thread_runtime();
662 let _guard = install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::Error(
663 FakeLlmError::new(ErrorCategory::RateLimit, "throttled").with_retry_after_ms(2_500),
664 )));
665
666 runtime.block_on(async {
667 let err = FakeLlmProvider
668 .chat_impl(&fake_request(), None)
669 .await
670 .expect_err("fake error turn should fail");
671 let VmError::CategorizedError { message, .. } = err else {
672 panic!("expected CategorizedError");
673 };
674 assert!(
675 message.contains("retry-after: 2.5"),
676 "retry-after hint should be present in synthetic message: {message}"
677 );
678 });
679 }
680
681 #[test]
682 fn stalled_turn_advances_under_paused_clock() {
683 let runtime = tokio::runtime::Builder::new_current_thread()
684 .enable_all()
685 .start_paused(true)
686 .build()
687 .expect("paused runtime");
688 let _guard = install_fake_llm_script(
689 FakeLlmScript::default()
690 .push(FakeLlmTurn::Stalled(Duration::from_mins(1)))
691 .push(FakeLlmTurn::stream(vec![
692 FakeLlmEvent::Token("done".into()),
693 FakeLlmEvent::Done(FakeStopReason::EndTurn),
694 ])),
695 );
696
697 runtime.block_on(async {
698 let request = fake_request();
699 let chat = FakeLlmProvider.chat_impl(&request, None);
700 tokio::pin!(chat);
701
702 let polled = futures::poll!(&mut chat);
706 assert!(
707 matches!(polled, std::task::Poll::Pending),
708 "fake provider should be parked on the stall"
709 );
710
711 tokio::time::advance(Duration::from_mins(1)).await;
712 let result = chat.await.expect("after advance, fake call resolves");
713 assert_eq!(result.text, "done");
714 });
715 }
716
717 #[test]
718 fn multiple_turns_consumed_in_fifo_order() {
719 let runtime = current_thread_runtime();
720 let _guard = install_fake_llm_script(
721 FakeLlmScript::default()
722 .push(FakeLlmTurn::stream(vec![
723 FakeLlmEvent::Token("first".into()),
724 FakeLlmEvent::Done(FakeStopReason::EndTurn),
725 ]))
726 .push(FakeLlmTurn::stream(vec![
727 FakeLlmEvent::Token("second".into()),
728 FakeLlmEvent::Done(FakeStopReason::EndTurn),
729 ])),
730 );
731
732 runtime.block_on(async {
733 let first = FakeLlmProvider
734 .chat_impl(&fake_request(), None)
735 .await
736 .expect("first call");
737 let second = FakeLlmProvider
738 .chat_impl(&fake_request(), None)
739 .await
740 .expect("second call");
741 assert_eq!(first.text, "first");
742 assert_eq!(second.text, "second");
743 });
744
745 let calls = fake_llm_captured_calls();
746 assert_eq!(calls.len(), 2);
747 assert!(calls.iter().all(|c| c.provider == "fake"));
748 }
749
750 #[test]
751 #[should_panic(expected = "no script installed")]
752 fn calling_without_script_panics_with_explanatory_error() {
753 let runtime = current_thread_runtime();
754 runtime
756 .block_on(async {
757 FakeLlmProvider
758 .chat_impl(&fake_request(), None)
759 .await
760 .map_err(|e| e.to_string())
761 })
762 .unwrap();
763 }
764
765 #[test]
766 #[should_panic(expected = "unconsumed turn")]
767 fn drop_guard_asserts_on_unused_turns() {
768 let guard =
769 install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::stream(vec![
770 FakeLlmEvent::Done(FakeStopReason::EndTurn),
771 ])));
772 drop(guard);
774 }
775}