1use std::cell::RefCell;
49use std::time::Duration;
50
51use crate::llm::api::{DeltaSender, LlmRequestPayload, LlmResult, ProviderTelemetry};
52use crate::llm::provider::{LlmProvider, LlmProviderChat};
53use crate::value::{ErrorCategory, VmError};
54
55#[derive(Clone, Debug)]
57pub enum FakeLlmTurn {
58 Stream(Vec<FakeLlmEvent>),
61 Error(FakeLlmError),
64 Stalled(Duration),
68}
69
70impl FakeLlmTurn {
71 pub fn stream(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
73 Self::Stream(events.into_iter().collect())
74 }
75
76 pub fn error(category: ErrorCategory, message: impl Into<String>) -> Self {
78 Self::Error(FakeLlmError {
79 category,
80 message: message.into(),
81 retry_after_ms: None,
82 })
83 }
84}
85
/// One step of a scripted streaming response.
#[derive(Clone, Debug)]
pub enum FakeLlmEvent {
    /// A text chunk: appended to the response text and forwarded to the delta
    /// channel when the caller supplied one.
    Token(String),
    /// A tool call to report in the result.
    ToolCallDelta {
        /// Call identifier; an empty string is replaced during playback with
        /// an auto-generated `fake_call_<n>` id.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Arguments passed through verbatim as JSON.
        arguments: serde_json::Value,
    },
    /// Pause playback for the given duration (no sleep when it is zero).
    Stall(Duration),
    /// End the stream with an explicit stop reason; playback stops here, so
    /// any events scripted after it are ignored.
    Done(FakeStopReason),
}
106
/// Stop reason reported at the end of a scripted stream.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FakeStopReason {
    EndTurn,
    ToolUse,
    MaxTokens,
    StopSequence,
    /// Any other label, passed through verbatim.
    Custom(String),
}

impl FakeStopReason {
    /// Returns the label written into the result's `stop_reason` field.
    fn as_str(&self) -> &str {
        // Custom labels borrow from `self`; every other variant maps to a
        // fixed string.
        let fixed: &'static str = match self {
            Self::Custom(label) => return label.as_str(),
            Self::StopSequence => "stop_sequence",
            Self::MaxTokens => "max_tokens",
            Self::ToolUse => "tool_use",
            Self::EndTurn => "end_turn",
        };
        fixed
    }
}
128
129#[derive(Clone, Debug)]
131pub struct FakeLlmError {
132 pub category: ErrorCategory,
133 pub message: String,
134 pub retry_after_ms: Option<u64>,
138}
139
140impl FakeLlmError {
141 pub fn new(category: ErrorCategory, message: impl Into<String>) -> Self {
142 Self {
143 category,
144 message: message.into(),
145 retry_after_ms: None,
146 }
147 }
148
149 pub fn with_retry_after_ms(mut self, ms: u64) -> Self {
150 self.retry_after_ms = Some(ms);
151 self
152 }
153}
154
155#[derive(Clone, Debug, Default)]
162pub struct FakeLlmScript {
163 pub turns: Vec<FakeLlmTurn>,
164}
165
166impl FakeLlmScript {
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn streaming(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
173 Self {
174 turns: vec![FakeLlmTurn::stream(events)],
175 }
176 }
177
178 pub fn erroring(category: ErrorCategory, message: impl Into<String>) -> Self {
180 Self {
181 turns: vec![FakeLlmTurn::error(category, message)],
182 }
183 }
184
185 pub fn push(mut self, turn: FakeLlmTurn) -> Self {
187 self.turns.push(turn);
188 self
189 }
190}
191
192#[derive(Clone, Debug)]
195pub struct FakeLlmCall {
196 pub provider: String,
197 pub model: String,
198 pub system: Option<String>,
199 pub messages: Vec<serde_json::Value>,
200 pub native_tools: Option<Vec<serde_json::Value>>,
201 pub stream: bool,
202}
203
204impl FakeLlmCall {
205 fn from_request(request: &LlmRequestPayload) -> Self {
206 Self {
207 provider: request.provider.clone(),
208 model: request.model.clone(),
209 system: request.system.clone(),
210 messages: request.messages.clone(),
211 native_tools: request.native_tools.clone(),
212 stream: request.stream,
213 }
214 }
215}
216
// Per-thread state for the fake provider: a script installed on one thread is
// not visible to calls made on another.
thread_local! {
    // Remaining scripted turns; `take_next_turn` removes them from the front.
    static FAKE_LLM_TURNS: RefCell<Vec<FakeLlmTurn>> = const { RefCell::new(Vec::new()) };
    // Requests recorded by `take_next_turn`; cleared when a script is
    // installed and again when its guard is dropped.
    static FAKE_LLM_CALLS: RefCell<Vec<FakeLlmCall>> = const { RefCell::new(Vec::new()) };
}
221
222#[must_use = "FakeLlmGuard asserts on drop; bind it to a `_guard` local"]
230pub fn install_fake_llm_script(script: FakeLlmScript) -> FakeLlmGuard {
231 FAKE_LLM_TURNS.with(|turns| {
232 let mut turns = turns.borrow_mut();
233 assert!(
234 turns.is_empty(),
235 "FakeLlmProvider: a script is already installed; drop the previous guard before installing a new one"
236 );
237 *turns = script.turns;
238 });
239 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
240 FakeLlmGuard { _priv: () }
241}
242
243pub fn fake_llm_captured_calls() -> Vec<FakeLlmCall> {
245 FAKE_LLM_CALLS.with(|calls| calls.borrow().clone())
246}
247
248#[must_use]
254pub struct FakeLlmGuard {
255 _priv: (),
256}
257
258impl Drop for FakeLlmGuard {
259 fn drop(&mut self) {
260 let remaining = FAKE_LLM_TURNS.with(|turns| std::mem::take(&mut *turns.borrow_mut()));
261 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
262 if std::thread::panicking() {
265 return;
266 }
267 assert!(
268 remaining.is_empty(),
269 "FakeLlmProvider script had {} unconsumed turn(s); did the code under test make fewer LLM calls than expected?",
270 remaining.len()
271 );
272 }
273}
274
275fn take_next_turn(request: &LlmRequestPayload) -> Result<FakeLlmTurn, VmError> {
278 FAKE_LLM_CALLS.with(|calls| {
279 calls.borrow_mut().push(FakeLlmCall::from_request(request));
280 });
281 FAKE_LLM_TURNS.with(|turns| {
282 let mut turns = turns.borrow_mut();
283 if turns.is_empty() {
284 Err(VmError::Runtime(
285 "FakeLlmProvider: no script installed (or script exhausted) — install_fake_llm_script() must precede llm_call(provider: \"fake\")".to_string()
286 ))
287 } else {
288 Ok(turns.remove(0))
289 }
290 })
291}
292
293fn fake_error_to_vm_error(err: &FakeLlmError) -> VmError {
296 let message = match err.retry_after_ms {
297 Some(ms) => {
298 let secs = (ms as f64 / 1000.0).max(0.0);
299 let sep = if err.message.is_empty() || err.message.ends_with('\n') {
300 ""
301 } else {
302 "\n"
303 };
304 format!("{}{sep}retry-after: {secs}\n", err.message)
305 }
306 None => err.message.clone(),
307 };
308 VmError::CategorizedError {
309 message,
310 category: err.category.clone(),
311 }
312}
313
/// Stateless provider that answers from the script installed on the current
/// thread instead of contacting a real LLM.
pub(crate) struct FakeLlmProvider;

impl LlmProvider for FakeLlmProvider {
    /// Provider name; the same `"fake"` string that `should_intercept`
    /// compares against.
    fn name(&self) -> &'static str {
        "fake"
    }

    /// Always `false`.
    // NOTE(review): presumably this tells callers that no model name has to
    // be configured for this provider — confirm against the trait's docs.
    fn requires_model(&self) -> bool {
        false
    }
}

impl LlmProviderChat for FakeLlmProvider {
    /// Trait entry point: boxes and pins the future returned by `chat_impl`,
    /// forwarding `request` and `delta_tx` unchanged.
    fn chat<'a>(
        &'a self,
        request: &'a LlmRequestPayload,
        delta_tx: Option<DeltaSender>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<LlmResult, VmError>> + 'a>> {
        Box::pin(self.chat_impl(request, delta_tx))
    }
}
336
337impl FakeLlmProvider {
338 pub(crate) fn should_intercept(provider: &str) -> bool {
340 provider == "fake"
341 }
342
343 pub(crate) async fn chat_impl(
344 &self,
345 request: &LlmRequestPayload,
346 delta_tx: Option<DeltaSender>,
347 ) -> Result<LlmResult, VmError> {
348 loop {
349 let turn = take_next_turn(request)?;
350 match turn {
351 FakeLlmTurn::Stalled(duration) => {
352 if !duration.is_zero() {
353 tokio::time::sleep(duration).await;
354 }
355 }
359 FakeLlmTurn::Error(err) => {
360 return Err(fake_error_to_vm_error(&err));
361 }
362 FakeLlmTurn::Stream(events) => {
363 return play_stream(request, events, delta_tx).await;
364 }
365 }
366 }
367 }
368}
369
370async fn play_stream(
371 request: &LlmRequestPayload,
372 events: Vec<FakeLlmEvent>,
373 delta_tx: Option<DeltaSender>,
374) -> Result<LlmResult, VmError> {
375 let mut text = String::new();
376 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
377 let mut blocks: Vec<serde_json::Value> = Vec::new();
378 let mut stop_reason: Option<FakeStopReason> = None;
379 let mut next_tool_index: usize = 1;
380 let mut schema_watch = crate::llm::api::StreamSchemaWatch::from_payload(request);
384
385 for event in events {
386 match event {
387 FakeLlmEvent::Token(chunk) => {
388 if let Some(tx) = delta_tx.as_ref() {
389 let _ = tx.send(chunk.clone());
390 }
391 text.push_str(&chunk);
392 if let Some(watch) = schema_watch.as_mut() {
393 if let Some(abort) = watch.observe(&chunk) {
394 return Err(abort.into_vm_error());
395 }
396 }
397 }
398 FakeLlmEvent::ToolCallDelta {
399 id,
400 name,
401 arguments,
402 } => {
403 let id = if id.is_empty() {
404 let auto = format!("fake_call_{next_tool_index}");
405 next_tool_index += 1;
406 auto
407 } else {
408 id
409 };
410 tool_calls.push(serde_json::json!({
411 "id": id,
412 "type": "tool_call",
413 "name": name,
414 "arguments": arguments,
415 }));
416 blocks.push(serde_json::json!({
417 "type": "tool_call",
418 "id": id,
419 "name": name,
420 "arguments": arguments,
421 "visibility": "internal",
422 }));
423 }
424 FakeLlmEvent::Stall(duration) => {
425 if !duration.is_zero() {
426 tokio::time::sleep(duration).await;
427 }
428 }
429 FakeLlmEvent::Done(reason) => {
430 stop_reason = Some(reason);
431 break;
432 }
433 }
434 }
435
436 if !text.is_empty() {
437 let text_block = serde_json::json!({
440 "type": "output_text",
441 "text": text,
442 "visibility": "public",
443 });
444 blocks.insert(0, text_block);
445 }
446
447 let stop_reason = stop_reason.unwrap_or(if tool_calls.is_empty() {
448 FakeStopReason::EndTurn
449 } else {
450 FakeStopReason::ToolUse
451 });
452
453 Ok(LlmResult {
454 served_fast: false,
455 text,
456 tool_calls,
457 input_tokens: count_input_tokens(&request.messages),
458 output_tokens: 0,
459 cache_read_tokens: 0,
460 cache_write_tokens: 0,
461 cache_supported: true,
462 model: request.model.clone(),
463 provider: "fake".to_string(),
464 thinking: None,
465 thinking_summary: None,
466 stop_reason: Some(stop_reason.as_str().to_string()),
467 blocks,
468 logprobs: Vec::new(),
469 telemetry: ProviderTelemetry::default(),
470 })
471}
472
473fn count_input_tokens(messages: &[serde_json::Value]) -> i64 {
477 fn collect(value: &serde_json::Value, out: &mut String) {
478 match value {
479 serde_json::Value::String(text) => {
480 out.push_str(text);
481 out.push('\n');
482 }
483 serde_json::Value::Array(items) => {
484 for item in items {
485 collect(item, out);
486 }
487 }
488 serde_json::Value::Object(map) => {
489 for value in map.values() {
490 collect(value, out);
491 }
492 }
493 _ => {}
494 }
495 }
496 let mut buf = String::new();
497 for message in messages {
498 collect(message, &mut buf);
499 }
500 buf.len() as i64
501}
502
// Unit tests for the fake provider. Every test installs its own script; the
// script state is thread-local, so a script is only visible to calls made on
// the thread that installed it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::api::{LlmApiMode, ThinkingConfig};
    use crate::llm::api::{LlmRequestPayload, LlmRouteFallback, OutputFormat};

    // Builds a minimal streaming request addressed to the "fake" provider with
    // a single user message; every optional knob is left off or empty.
    fn fake_request() -> LlmRequestPayload {
        LlmRequestPayload {
            provider: "fake".to_string(),
            model: "fake-model".to_string(),
            api_key: String::new(),
            api_mode: LlmApiMode::ChatCompletions,
            fallback_chain: Vec::new(),
            route_fallbacks: Vec::<LlmRouteFallback>::new(),
            messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
            system: None,
            max_tokens: 64,
            temperature: None,
            top_p: None,
            top_k: None,
            logprobs: false,
            top_logprobs: None,
            stop: None,
            seed: None,
            frequency_penalty: None,
            presence_penalty: None,
            fast: false,
            output_format: OutputFormat::Text,
            response_format: None,
            json_schema: None,
            output_schema: None,
            schema_stream_abort: false,
            thinking: ThinkingConfig::Disabled,
            anthropic_beta_features: Vec::new(),
            vision: false,
            native_tools: None,
            provider_tools: Vec::new(),
            tool_choice: None,
            cache: false,
            timeout: None,
            stream: true,
            provider_overrides: None,
            previous_response_id: None,
            store: None,
            background: None,
            truncation: None,
            compact: None,
            include: None,
            max_tool_calls: None,
            prefill: None,
            session_id: None,
            reminder_lifecycle: Vec::new(),
        }
    }

    // Single-threaded runtime with a normally running clock. The future is
    // driven on the calling thread, which is the thread that owns the script.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(false)
            .build()
            .expect("runtime")
    }

    // Tokens reach the delta channel in script order and are concatenated
    // into the result text and a single output_text block.
    #[test]
    fn streaming_turn_emits_deltas_in_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("hello ".into()),
            FakeLlmEvent::Token("world".into()),
            FakeLlmEvent::Done(FakeStopReason::EndTurn),
        ]));

        runtime.block_on(async {
            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), Some(tx))
                .await
                .expect("fake call should succeed");

            // The call has finished, so every delta is already queued.
            let mut deltas = Vec::new();
            while let Ok(delta) = rx.try_recv() {
                deltas.push(delta);
            }

            assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
            assert_eq!(result.text, "hello world");
            assert_eq!(result.provider, "fake");
            assert_eq!(result.stop_reason.as_deref(), Some("end_turn"));
            assert_eq!(result.blocks.len(), 1);
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            assert_eq!(result.blocks[0]["text"].as_str(), Some("hello world"));
        });
        // Read while `_guard` is alive: dropping it clears the captured calls.
        assert_eq!(fake_llm_captured_calls().len(), 1);
    }

    // A tool call with an empty id gets an auto-generated one and shows up in
    // both `tool_calls` and `blocks`.
    #[test]
    fn tool_call_deltas_become_tool_calls_and_blocks() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("calling tool".into()),
            FakeLlmEvent::ToolCallDelta {
                id: String::new(),
                name: "search".into(),
                arguments: serde_json::json!({"q": "harn"}),
            },
            FakeLlmEvent::Done(FakeStopReason::ToolUse),
        ]));

        runtime.block_on(async {
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("fake call should succeed");

            assert_eq!(result.tool_calls.len(), 1);
            assert_eq!(result.tool_calls[0]["name"].as_str(), Some("search"));
            assert_eq!(result.tool_calls[0]["id"].as_str(), Some("fake_call_1"));
            assert_eq!(
                result.tool_calls[0]["arguments"]["q"].as_str(),
                Some("harn")
            );
            assert_eq!(result.stop_reason.as_deref(), Some("tool_use"));
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            // The text block comes first; the tool-call block follows it.
            assert_eq!(result.blocks[1]["type"].as_str(), Some("tool_call"));
            assert_eq!(result.blocks[1]["name"].as_str(), Some("search"));
        });
    }

    // An error turn surfaces as a CategorizedError carrying the scripted
    // category and message.
    #[test]
    fn error_turn_returns_categorized_error() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::erroring(
            ErrorCategory::RateLimit,
            "throttled",
        ));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            match err {
                VmError::CategorizedError { message, category } => {
                    assert_eq!(category, ErrorCategory::RateLimit);
                    assert!(
                        message.contains("throttled"),
                        "error message should pass through: {message}"
                    );
                }
                other => panic!("expected CategorizedError, got {other:?}"),
            }
        });
    }

    // A retry hint of 2500 ms is rendered into the message as 2.5 seconds.
    #[test]
    fn error_turn_embeds_retry_after_hint() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::Error(
            FakeLlmError::new(ErrorCategory::RateLimit, "throttled").with_retry_after_ms(2_500),
        )));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            let VmError::CategorizedError { message, .. } = err else {
                panic!("expected CategorizedError");
            };
            assert!(
                message.contains("retry-after: 2.5"),
                "retry-after hint should be present in synthetic message: {message}"
            );
        });
    }

    // With the runtime clock paused, a stalled turn keeps the call pending
    // until the clock is advanced past the stall, after which the next turn
    // is served by the same call.
    #[test]
    fn stalled_turn_advances_under_paused_clock() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(true)
            .build()
            .expect("paused runtime");
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::Stalled(Duration::from_mins(1)))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("done".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let request = fake_request();
            let chat = FakeLlmProvider.chat_impl(&request, None);
            tokio::pin!(chat);

            // Poll exactly once: the future should consume the stalled turn
            // and then park on the sleep instead of completing.
            let polled = futures::poll!(&mut chat);
            assert!(
                matches!(polled, std::task::Poll::Pending),
                "fake provider should be parked on the stall"
            );

            tokio::time::advance(Duration::from_mins(1)).await;
            let result = chat.await.expect("after advance, fake call resolves");
            assert_eq!(result.text, "done");
        });
    }

    // Two scripted turns are served to two consecutive calls in the order
    // they were pushed, and both requests are captured.
    #[test]
    fn multiple_turns_consumed_in_fifo_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("first".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ]))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("second".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let first = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("first call");
            let second = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("second call");
            assert_eq!(first.text, "first");
            assert_eq!(second.text, "second");
        });

        let calls = fake_llm_captured_calls();
        assert_eq!(calls.len(), 2);
        assert!(calls.iter().all(|c| c.provider == "fake"));
    }

    // Without an installed script the call returns an error; the test turns
    // it into a string and unwraps so the message becomes the panic payload
    // that `should_panic` matches against.
    #[test]
    #[should_panic(expected = "no script installed")]
    fn calling_without_script_panics_with_explanatory_error() {
        let runtime = current_thread_runtime();
        runtime
            .block_on(async {
                FakeLlmProvider
                    .chat_impl(&fake_request(), None)
                    .await
                    .map_err(|e| e.to_string())
            })
            .unwrap();
    }

    // Dropping the guard while a scripted turn is still unconsumed must
    // panic with the "unconsumed turn" assertion message.
    #[test]
    #[should_panic(expected = "unconsumed turn")]
    fn drop_guard_asserts_on_unused_turns() {
        let guard =
            install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::stream(vec![
                FakeLlmEvent::Done(FakeStopReason::EndTurn),
            ])));
        drop(guard);
    }
}