1use std::cell::RefCell;
49use std::time::Duration;
50
51use crate::llm::api::{DeltaSender, LlmRequestPayload, LlmResult, ProviderTelemetry};
52use crate::llm::provider::{LlmProvider, LlmProviderChat};
53use crate::value::{ErrorCategory, VmError};
54
55#[derive(Clone, Debug)]
57pub enum FakeLlmTurn {
58 Stream(Vec<FakeLlmEvent>),
61 Error(FakeLlmError),
64 Stalled(Duration),
68}
69
70impl FakeLlmTurn {
71 pub fn stream(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
73 Self::Stream(events.into_iter().collect())
74 }
75
76 pub fn error(category: ErrorCategory, message: impl Into<String>) -> Self {
78 Self::Error(FakeLlmError {
79 category,
80 message: message.into(),
81 retry_after_ms: None,
82 })
83 }
84}
85
86#[derive(Clone, Debug)]
88pub enum FakeLlmEvent {
89 Token(String),
92 ToolCallDelta {
95 id: String,
96 name: String,
97 arguments: serde_json::Value,
98 },
99 Stall(Duration),
102 Done(FakeStopReason),
105}
106
/// Stop reason reported by a fake streamed turn.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FakeStopReason {
    EndTurn,
    ToolUse,
    MaxTokens,
    StopSequence,
    /// Any other reason, reported verbatim.
    Custom(String),
}

impl FakeStopReason {
    /// Wire label for this reason: the snake_case name for the built-in
    /// variants, or the stored string for `Custom`.
    fn as_str(&self) -> &str {
        match self {
            FakeStopReason::Custom(verbatim) => verbatim.as_str(),
            FakeStopReason::EndTurn => "end_turn",
            FakeStopReason::ToolUse => "tool_use",
            FakeStopReason::MaxTokens => "max_tokens",
            FakeStopReason::StopSequence => "stop_sequence",
        }
    }
}
128
129#[derive(Clone, Debug)]
131pub struct FakeLlmError {
132 pub category: ErrorCategory,
133 pub message: String,
134 pub retry_after_ms: Option<u64>,
138}
139
140impl FakeLlmError {
141 pub fn new(category: ErrorCategory, message: impl Into<String>) -> Self {
142 Self {
143 category,
144 message: message.into(),
145 retry_after_ms: None,
146 }
147 }
148
149 pub fn with_retry_after_ms(mut self, ms: u64) -> Self {
150 self.retry_after_ms = Some(ms);
151 self
152 }
153}
154
155#[derive(Clone, Debug, Default)]
162pub struct FakeLlmScript {
163 pub turns: Vec<FakeLlmTurn>,
164}
165
166impl FakeLlmScript {
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn streaming(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
173 Self {
174 turns: vec![FakeLlmTurn::stream(events)],
175 }
176 }
177
178 pub fn erroring(category: ErrorCategory, message: impl Into<String>) -> Self {
180 Self {
181 turns: vec![FakeLlmTurn::error(category, message)],
182 }
183 }
184
185 pub fn push(mut self, turn: FakeLlmTurn) -> Self {
187 self.turns.push(turn);
188 self
189 }
190}
191
192#[derive(Clone, Debug)]
195pub struct FakeLlmCall {
196 pub provider: String,
197 pub model: String,
198 pub system: Option<String>,
199 pub messages: Vec<serde_json::Value>,
200 pub native_tools: Option<Vec<serde_json::Value>>,
201 pub stream: bool,
202}
203
204impl FakeLlmCall {
205 fn from_request(request: &LlmRequestPayload) -> Self {
206 Self {
207 provider: request.provider.clone(),
208 model: request.model.clone(),
209 system: request.system.clone(),
210 messages: request.messages.clone(),
211 native_tools: request.native_tools.clone(),
212 stream: request.stream,
213 }
214 }
215}
216
217thread_local! {
218 static FAKE_LLM_TURNS: RefCell<Vec<FakeLlmTurn>> = const { RefCell::new(Vec::new()) };
219 static FAKE_LLM_CALLS: RefCell<Vec<FakeLlmCall>> = const { RefCell::new(Vec::new()) };
220}
221
222#[must_use = "FakeLlmGuard asserts on drop; bind it to a `_guard` local"]
230pub fn install_fake_llm_script(script: FakeLlmScript) -> FakeLlmGuard {
231 FAKE_LLM_TURNS.with(|turns| {
232 let mut turns = turns.borrow_mut();
233 assert!(
234 turns.is_empty(),
235 "FakeLlmProvider: a script is already installed; drop the previous guard before installing a new one"
236 );
237 *turns = script.turns;
238 });
239 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
240 FakeLlmGuard { _priv: () }
241}
242
243pub fn fake_llm_captured_calls() -> Vec<FakeLlmCall> {
245 FAKE_LLM_CALLS.with(|calls| calls.borrow().clone())
246}
247
248#[must_use]
254pub struct FakeLlmGuard {
255 _priv: (),
256}
257
258impl Drop for FakeLlmGuard {
259 fn drop(&mut self) {
260 let remaining = FAKE_LLM_TURNS.with(|turns| std::mem::take(&mut *turns.borrow_mut()));
261 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
262 if std::thread::panicking() {
265 return;
266 }
267 assert!(
268 remaining.is_empty(),
269 "FakeLlmProvider script had {} unconsumed turn(s); did the code under test make fewer LLM calls than expected?",
270 remaining.len()
271 );
272 }
273}
274
275fn take_next_turn(request: &LlmRequestPayload) -> Result<FakeLlmTurn, VmError> {
278 FAKE_LLM_CALLS.with(|calls| {
279 calls.borrow_mut().push(FakeLlmCall::from_request(request));
280 });
281 FAKE_LLM_TURNS.with(|turns| {
282 let mut turns = turns.borrow_mut();
283 if turns.is_empty() {
284 Err(VmError::Runtime(
285 "FakeLlmProvider: no script installed (or script exhausted) — install_fake_llm_script() must precede llm_call(provider: \"fake\")".to_string()
286 ))
287 } else {
288 Ok(turns.remove(0))
289 }
290 })
291}
292
293fn fake_error_to_vm_error(err: &FakeLlmError) -> VmError {
296 let message = match err.retry_after_ms {
297 Some(ms) => {
298 let secs = (ms as f64 / 1000.0).max(0.0);
299 let sep = if err.message.is_empty() || err.message.ends_with('\n') {
300 ""
301 } else {
302 "\n"
303 };
304 format!("{}{sep}retry-after: {secs}\n", err.message)
305 }
306 None => err.message.clone(),
307 };
308 VmError::CategorizedError {
309 message,
310 category: err.category.clone(),
311 }
312}
313
314pub(crate) struct FakeLlmProvider;
316
317impl LlmProvider for FakeLlmProvider {
318 fn name(&self) -> &str {
319 "fake"
320 }
321
322 fn requires_model(&self) -> bool {
323 false
324 }
325}
326
327impl LlmProviderChat for FakeLlmProvider {
328 fn chat<'a>(
329 &'a self,
330 request: &'a LlmRequestPayload,
331 delta_tx: Option<DeltaSender>,
332 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<LlmResult, VmError>> + 'a>> {
333 Box::pin(self.chat_impl(request, delta_tx))
334 }
335}
336
337impl FakeLlmProvider {
338 pub(crate) fn should_intercept(provider: &str) -> bool {
340 provider == "fake"
341 }
342
343 pub(crate) async fn chat_impl(
344 &self,
345 request: &LlmRequestPayload,
346 delta_tx: Option<DeltaSender>,
347 ) -> Result<LlmResult, VmError> {
348 loop {
349 let turn = take_next_turn(request)?;
350 match turn {
351 FakeLlmTurn::Stalled(duration) => {
352 if !duration.is_zero() {
353 tokio::time::sleep(duration).await;
354 }
355 }
359 FakeLlmTurn::Error(err) => {
360 return Err(fake_error_to_vm_error(&err));
361 }
362 FakeLlmTurn::Stream(events) => {
363 return play_stream(request, events, delta_tx).await;
364 }
365 }
366 }
367 }
368}
369
370async fn play_stream(
371 request: &LlmRequestPayload,
372 events: Vec<FakeLlmEvent>,
373 delta_tx: Option<DeltaSender>,
374) -> Result<LlmResult, VmError> {
375 let mut text = String::new();
376 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
377 let mut blocks: Vec<serde_json::Value> = Vec::new();
378 let mut stop_reason: Option<FakeStopReason> = None;
379 let mut next_tool_index: usize = 1;
380 let mut schema_watch = crate::llm::api::StreamSchemaWatch::from_payload(request);
384
385 for event in events {
386 match event {
387 FakeLlmEvent::Token(chunk) => {
388 if let Some(tx) = delta_tx.as_ref() {
389 let _ = tx.send(chunk.clone());
390 }
391 text.push_str(&chunk);
392 if let Some(watch) = schema_watch.as_mut() {
393 if let Some(abort) = watch.observe(&chunk) {
394 return Err(abort.into_vm_error());
395 }
396 }
397 }
398 FakeLlmEvent::ToolCallDelta {
399 id,
400 name,
401 arguments,
402 } => {
403 let id = if id.is_empty() {
404 let auto = format!("fake_call_{}", next_tool_index);
405 next_tool_index += 1;
406 auto
407 } else {
408 id
409 };
410 tool_calls.push(serde_json::json!({
411 "id": id,
412 "type": "tool_call",
413 "name": name,
414 "arguments": arguments,
415 }));
416 blocks.push(serde_json::json!({
417 "type": "tool_call",
418 "id": id,
419 "name": name,
420 "arguments": arguments,
421 "visibility": "internal",
422 }));
423 }
424 FakeLlmEvent::Stall(duration) => {
425 if !duration.is_zero() {
426 tokio::time::sleep(duration).await;
427 }
428 }
429 FakeLlmEvent::Done(reason) => {
430 stop_reason = Some(reason);
431 break;
432 }
433 }
434 }
435
436 if !text.is_empty() {
437 let text_block = serde_json::json!({
440 "type": "output_text",
441 "text": text,
442 "visibility": "public",
443 });
444 blocks.insert(0, text_block);
445 }
446
447 let stop_reason = stop_reason.unwrap_or(if tool_calls.is_empty() {
448 FakeStopReason::EndTurn
449 } else {
450 FakeStopReason::ToolUse
451 });
452
453 Ok(LlmResult {
454 text,
455 tool_calls,
456 input_tokens: count_input_tokens(&request.messages),
457 output_tokens: 0,
458 cache_read_tokens: 0,
459 cache_write_tokens: 0,
460 model: request.model.clone(),
461 provider: "fake".to_string(),
462 thinking: None,
463 thinking_summary: None,
464 stop_reason: Some(stop_reason.as_str().to_string()),
465 blocks,
466 logprobs: Vec::new(),
467 telemetry: ProviderTelemetry::default(),
468 })
469}
470
471fn count_input_tokens(messages: &[serde_json::Value]) -> i64 {
475 fn collect(value: &serde_json::Value, out: &mut String) {
476 match value {
477 serde_json::Value::String(text) => {
478 out.push_str(text);
479 out.push('\n');
480 }
481 serde_json::Value::Array(items) => {
482 for item in items {
483 collect(item, out);
484 }
485 }
486 serde_json::Value::Object(map) => {
487 for value in map.values() {
488 collect(value, out);
489 }
490 }
491 _ => {}
492 }
493 }
494 let mut buf = String::new();
495 for message in messages {
496 collect(message, &mut buf);
497 }
498 buf.len() as i64
499}
500
#[cfg(test)]
mod tests {
    // Unit tests for the fake provider. Script state and captured calls are
    // thread-local, so each test drives the provider on its own test thread via a
    // current-thread runtime. Tests that install a script keep the guard alive to
    // the end of the test, so the guard's drop also checks for unconsumed turns.
    use super::*;
    use crate::llm::api::{LlmApiMode, ThinkingConfig};
    use crate::llm::api::{LlmRequestPayload, LlmRouteFallback, OutputFormat};

    // Minimal request addressed to the `"fake"` provider: a single user message,
    // `stream: true`, and every optional knob unset or disabled.
    fn fake_request() -> LlmRequestPayload {
        LlmRequestPayload {
            provider: "fake".to_string(),
            model: "fake-model".to_string(),
            api_key: String::new(),
            api_mode: LlmApiMode::ChatCompletions,
            fallback_chain: Vec::new(),
            route_fallbacks: Vec::<LlmRouteFallback>::new(),
            messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
            system: None,
            max_tokens: 64,
            temperature: None,
            top_p: None,
            top_k: None,
            logprobs: false,
            top_logprobs: None,
            stop: None,
            seed: None,
            frequency_penalty: None,
            presence_penalty: None,
            output_format: OutputFormat::Text,
            response_format: None,
            json_schema: None,
            output_schema: None,
            schema_stream_abort: false,
            thinking: ThinkingConfig::Disabled,
            anthropic_beta_features: Vec::new(),
            vision: false,
            native_tools: None,
            provider_tools: Vec::new(),
            tool_choice: None,
            cache: false,
            timeout: None,
            stream: true,
            provider_overrides: None,
            previous_response_id: None,
            store: None,
            background: None,
            truncation: None,
            compact: None,
            include: None,
            max_tool_calls: None,
            prefill: None,
            session_id: None,
            reminder_lifecycle: Vec::new(),
        }
    }

    // Current-thread runtime with a real (unpaused) clock.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(false)
            .build()
            .expect("runtime")
    }

    // Token events are forwarded to the delta channel in order and concatenated
    // into a single output_text block.
    #[test]
    fn streaming_turn_emits_deltas_in_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("hello ".into()),
            FakeLlmEvent::Token("world".into()),
            FakeLlmEvent::Done(FakeStopReason::EndTurn),
        ]));

        runtime.block_on(async {
            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), Some(tx))
                .await
                .expect("fake call should succeed");

            // The call has completed, so every delta is already buffered.
            let mut deltas = Vec::new();
            while let Ok(delta) = rx.try_recv() {
                deltas.push(delta);
            }

            assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
            assert_eq!(result.text, "hello world");
            assert_eq!(result.provider, "fake");
            assert_eq!(result.stop_reason.as_deref(), Some("end_turn"));
            assert_eq!(result.blocks.len(), 1);
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            assert_eq!(result.blocks[0]["text"].as_str(), Some("hello world"));
        });
        // Must run while `_guard` is alive: dropping the guard clears the log.
        assert_eq!(fake_llm_captured_calls().len(), 1);
    }

    // An empty tool-call id is auto-assigned `fake_call_1`. The text block is
    // placed ahead of the tool-call block.
    #[test]
    fn tool_call_deltas_become_tool_calls_and_blocks() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("calling tool".into()),
            FakeLlmEvent::ToolCallDelta {
                id: String::new(),
                name: "search".into(),
                arguments: serde_json::json!({"q": "harn"}),
            },
            FakeLlmEvent::Done(FakeStopReason::ToolUse),
        ]));

        runtime.block_on(async {
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("fake call should succeed");

            assert_eq!(result.tool_calls.len(), 1);
            assert_eq!(result.tool_calls[0]["name"].as_str(), Some("search"));
            assert_eq!(result.tool_calls[0]["id"].as_str(), Some("fake_call_1"));
            assert_eq!(
                result.tool_calls[0]["arguments"]["q"].as_str(),
                Some("harn")
            );
            assert_eq!(result.stop_reason.as_deref(), Some("tool_use"));
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            assert_eq!(result.blocks[1]["type"].as_str(), Some("tool_call"));
            assert_eq!(result.blocks[1]["name"].as_str(), Some("search"));
        });
    }

    // An Error turn surfaces as VmError::CategorizedError with category and message
    // passed through.
    #[test]
    fn error_turn_returns_categorized_error() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::erroring(
            ErrorCategory::RateLimit,
            "throttled",
        ));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            match err {
                VmError::CategorizedError { message, category } => {
                    assert_eq!(category, ErrorCategory::RateLimit);
                    assert!(
                        message.contains("throttled"),
                        "error message should pass through: {message}"
                    );
                }
                other => panic!("expected CategorizedError, got {other:?}"),
            }
        });
    }

    // A retry_after_ms of 2500 is rendered in seconds ("2.5") in the error message.
    #[test]
    fn error_turn_embeds_retry_after_hint() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::Error(
            FakeLlmError::new(ErrorCategory::RateLimit, "throttled").with_retry_after_ms(2_500),
        )));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            let VmError::CategorizedError { message, .. } = err else {
                panic!("expected CategorizedError");
            };
            assert!(
                message.contains("retry-after: 2.5"),
                "retry-after hint should be present in synthetic message: {message}"
            );
        });
    }

    // With a paused clock, the first poll parks on the 60s stall. Advancing the
    // clock lets the same call move on to the following stream turn.
    #[test]
    fn stalled_turn_advances_under_paused_clock() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(true)
            .build()
            .expect("paused runtime");
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::Stalled(Duration::from_secs(60)))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("done".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let request = fake_request();
            let chat = FakeLlmProvider.chat_impl(&request, None);
            tokio::pin!(chat);

            // Poll exactly once so the future registers its sleep before the clock moves.
            let polled = futures::poll!(&mut chat);
            assert!(
                matches!(polled, std::task::Poll::Pending),
                "fake provider should be parked on the stall"
            );

            tokio::time::advance(Duration::from_secs(60)).await;
            let result = chat.await.expect("after advance, fake call resolves");
            assert_eq!(result.text, "done");
        });
    }

    // Consecutive calls consume scripted turns front-to-back, and each call is
    // captured.
    #[test]
    fn multiple_turns_consumed_in_fifo_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("first".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ]))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("second".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let first = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("first call");
            let second = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("second call");
            assert_eq!(first.text, "first");
            assert_eq!(second.text, "second");
        });

        // Checked before `_guard` drops, since dropping it clears the captured calls.
        let calls = fake_llm_captured_calls();
        assert_eq!(calls.len(), 2);
        assert!(calls.iter().all(|c| c.provider == "fake"));
    }

    // With no script installed, chat_impl returns a Runtime error. Unwrapping the
    // stringified error panics with a message containing "no script installed".
    #[test]
    #[should_panic(expected = "no script installed")]
    fn calling_without_script_panics_with_explanatory_error() {
        let runtime = current_thread_runtime();
        runtime
            .block_on(async {
                FakeLlmProvider
                    .chat_impl(&fake_request(), None)
                    .await
                    .map_err(|e| e.to_string())
            })
            .unwrap();
    }

    // Dropping a guard while scripted turns remain triggers the drop-time assertion.
    #[test]
    #[should_panic(expected = "unconsumed turn")]
    fn drop_guard_asserts_on_unused_turns() {
        let guard =
            install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::stream(vec![
                FakeLlmEvent::Done(FakeStopReason::EndTurn),
            ])));
        drop(guard);
    }
}