1use std::cell::RefCell;
49use std::time::Duration;
50
51use crate::llm::api::{DeltaSender, LlmRequestPayload, LlmResult, ProviderTelemetry};
52use crate::llm::provider::{LlmProvider, LlmProviderChat};
53use crate::value::{ErrorCategory, VmError};
54
/// One scripted step that the fake provider consumes when an LLM call is made.
#[derive(Clone, Debug)]
pub enum FakeLlmTurn {
    /// Play the listed events back as a streamed response and return the
    /// assembled result.
    Stream(Vec<FakeLlmEvent>),
    /// Fail the call with the given categorized error.
    Error(FakeLlmError),
    /// Sleep for the given duration (skipped when zero), then continue with
    /// the next scripted turn within the same call.
    Stalled(Duration),
}
69
70impl FakeLlmTurn {
71 pub fn stream(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
73 Self::Stream(events.into_iter().collect())
74 }
75
76 pub fn error(category: ErrorCategory, message: impl Into<String>) -> Self {
78 Self::Error(FakeLlmError {
79 category,
80 message: message.into(),
81 retry_after_ms: None,
82 })
83 }
84}
85
/// A single event inside a scripted [`FakeLlmTurn::Stream`] turn.
#[derive(Clone, Debug)]
pub enum FakeLlmEvent {
    /// A chunk of output text. It is forwarded to the delta sender when one
    /// is supplied and appended to the final result text.
    Token(String),
    /// A tool call emitted by the fake model. It is recorded both in the
    /// result's tool calls and in its blocks.
    ToolCallDelta {
        /// Tool call id; an empty string is replaced by a generated
        /// `fake_call_<n>` id during playback.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// JSON arguments passed to the tool.
        arguments: serde_json::Value,
    },
    /// Sleep for the given duration in the middle of the stream (skipped
    /// when zero).
    Stall(Duration),
    /// Ends playback with the given stop reason; any events listed after it
    /// are not played.
    Done(FakeStopReason),
}
106
/// Why a scripted stream stopped; rendered into the result's `stop_reason`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FakeStopReason {
    /// Rendered as `"end_turn"`.
    EndTurn,
    /// Rendered as `"tool_use"`.
    ToolUse,
    /// Rendered as `"max_tokens"`.
    MaxTokens,
    /// Rendered as `"stop_sequence"`.
    StopSequence,
    /// Rendered verbatim as the contained string.
    Custom(String),
}

impl FakeStopReason {
    /// Returns the wire-format label for this stop reason.
    ///
    /// The built-in variants map to fixed snake_case labels; `Custom`
    /// borrows its own string unchanged.
    fn as_str(&self) -> &str {
        match self {
            Self::Custom(label) => &label[..],
            Self::StopSequence => "stop_sequence",
            Self::MaxTokens => "max_tokens",
            Self::ToolUse => "tool_use",
            Self::EndTurn => "end_turn",
        }
    }
}
128
/// Description of a scripted failure returned by a [`FakeLlmTurn::Error`] turn.
#[derive(Clone, Debug)]
pub struct FakeLlmError {
    /// Category carried over to the resulting `VmError::CategorizedError`.
    pub category: ErrorCategory,
    /// Human-readable error message.
    pub message: String,
    /// Optional retry delay in milliseconds. When set, a
    /// `retry-after: <seconds>` line is appended to the error message.
    pub retry_after_ms: Option<u64>,
}
139
140impl FakeLlmError {
141 pub fn new(category: ErrorCategory, message: impl Into<String>) -> Self {
142 Self {
143 category,
144 message: message.into(),
145 retry_after_ms: None,
146 }
147 }
148
149 pub fn with_retry_after_ms(mut self, ms: u64) -> Self {
150 self.retry_after_ms = Some(ms);
151 self
152 }
153}
154
/// An ordered list of scripted turns to install for the fake provider.
///
/// Turns are consumed from the front of `turns` as calls are made.
#[derive(Clone, Debug, Default)]
pub struct FakeLlmScript {
    /// Scripted turns, in the order they will be consumed.
    pub turns: Vec<FakeLlmTurn>,
}
165
166impl FakeLlmScript {
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn streaming(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
173 Self {
174 turns: vec![FakeLlmTurn::stream(events)],
175 }
176 }
177
178 pub fn erroring(category: ErrorCategory, message: impl Into<String>) -> Self {
180 Self {
181 turns: vec![FakeLlmTurn::error(category, message)],
182 }
183 }
184
185 pub fn push(mut self, turn: FakeLlmTurn) -> Self {
187 self.turns.push(turn);
188 self
189 }
190}
191
/// Snapshot of selected request fields, recorded each time the fake provider
/// pulls a scripted turn for a request.
#[derive(Clone, Debug)]
pub struct FakeLlmCall {
    /// Copy of the request's `provider` field.
    pub provider: String,
    /// Copy of the request's `model` field.
    pub model: String,
    /// Copy of the request's `system` field.
    pub system: Option<String>,
    /// Copy of the request's `messages`.
    pub messages: Vec<serde_json::Value>,
    /// Copy of the request's `native_tools`.
    pub native_tools: Option<Vec<serde_json::Value>>,
    /// Copy of the request's `stream` flag.
    pub stream: bool,
}
203
204impl FakeLlmCall {
205 fn from_request(request: &LlmRequestPayload) -> Self {
206 Self {
207 provider: request.provider.clone(),
208 model: request.model.clone(),
209 system: request.system.clone(),
210 messages: request.messages.clone(),
211 native_tools: request.native_tools.clone(),
212 stream: request.stream,
213 }
214 }
215}
216
// Per-thread state of the fake provider. Being thread-local, a script
// installed on one thread is not visible to calls made on another thread.
thread_local! {
    /// Scripted turns that have not been consumed yet; turns are taken from
    /// the front.
    static FAKE_LLM_TURNS: RefCell<Vec<FakeLlmTurn>> = const { RefCell::new(Vec::new()) };
    /// Request snapshots recorded since the last install (or guard drop).
    static FAKE_LLM_CALLS: RefCell<Vec<FakeLlmCall>> = const { RefCell::new(Vec::new()) };
}
221
222#[must_use = "FakeLlmGuard asserts on drop; bind it to a `_guard` local"]
230pub fn install_fake_llm_script(script: FakeLlmScript) -> FakeLlmGuard {
231 FAKE_LLM_TURNS.with(|turns| {
232 let mut turns = turns.borrow_mut();
233 assert!(
234 turns.is_empty(),
235 "FakeLlmProvider: a script is already installed; drop the previous guard before installing a new one"
236 );
237 *turns = script.turns;
238 });
239 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
240 FakeLlmGuard { _priv: () }
241}
242
243pub fn fake_llm_captured_calls() -> Vec<FakeLlmCall> {
245 FAKE_LLM_CALLS.with(|calls| calls.borrow().clone())
246}
247
/// Guard returned by `install_fake_llm_script`.
///
/// Dropping it clears the current thread's pending turns and captured calls,
/// and panics if any scripted turn was left unconsumed (unless the thread is
/// already panicking).
#[must_use]
pub struct FakeLlmGuard {
    // Private unit field: prevents construction outside this module.
    _priv: (),
}
257
258impl Drop for FakeLlmGuard {
259 fn drop(&mut self) {
260 let remaining = FAKE_LLM_TURNS.with(|turns| std::mem::take(&mut *turns.borrow_mut()));
261 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
262 if std::thread::panicking() {
265 return;
266 }
267 assert!(
268 remaining.is_empty(),
269 "FakeLlmProvider script had {} unconsumed turn(s); did the code under test make fewer LLM calls than expected?",
270 remaining.len()
271 );
272 }
273}
274
275fn take_next_turn(request: &LlmRequestPayload) -> Result<FakeLlmTurn, VmError> {
278 FAKE_LLM_CALLS.with(|calls| {
279 calls.borrow_mut().push(FakeLlmCall::from_request(request));
280 });
281 FAKE_LLM_TURNS.with(|turns| {
282 let mut turns = turns.borrow_mut();
283 if turns.is_empty() {
284 Err(VmError::Runtime(
285 "FakeLlmProvider: no script installed (or script exhausted) — install_fake_llm_script() must precede llm_call(provider: \"fake\")".to_string()
286 ))
287 } else {
288 Ok(turns.remove(0))
289 }
290 })
291}
292
293fn fake_error_to_vm_error(err: &FakeLlmError) -> VmError {
296 let message = match err.retry_after_ms {
297 Some(ms) => {
298 let secs = (ms as f64 / 1000.0).max(0.0);
299 let sep = if err.message.is_empty() || err.message.ends_with('\n') {
300 ""
301 } else {
302 "\n"
303 };
304 format!("{}{sep}retry-after: {secs}\n", err.message)
305 }
306 None => err.message.clone(),
307 };
308 VmError::CategorizedError {
309 message,
310 category: err.category.clone(),
311 }
312}
313
/// Stateless provider registered under the name `"fake"`; it replays the
/// script installed on the current thread instead of contacting a real LLM.
pub(crate) struct FakeLlmProvider;
316
impl LlmProvider for FakeLlmProvider {
    /// Provider name: always `"fake"`.
    fn name(&self) -> &str {
        "fake"
    }

    /// The fake provider reports that it does not require a model.
    fn requires_model(&self) -> bool {
        false
    }
}
326
327impl LlmProviderChat for FakeLlmProvider {
328 fn chat<'a>(
329 &'a self,
330 request: &'a LlmRequestPayload,
331 delta_tx: Option<DeltaSender>,
332 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<LlmResult, VmError>> + 'a>> {
333 Box::pin(self.chat_impl(request, delta_tx))
334 }
335}
336
337impl FakeLlmProvider {
338 pub(crate) fn should_intercept(provider: &str) -> bool {
340 provider == "fake"
341 }
342
343 pub(crate) async fn chat_impl(
344 &self,
345 request: &LlmRequestPayload,
346 delta_tx: Option<DeltaSender>,
347 ) -> Result<LlmResult, VmError> {
348 loop {
349 let turn = take_next_turn(request)?;
350 match turn {
351 FakeLlmTurn::Stalled(duration) => {
352 if !duration.is_zero() {
353 tokio::time::sleep(duration).await;
354 }
355 }
359 FakeLlmTurn::Error(err) => {
360 return Err(fake_error_to_vm_error(&err));
361 }
362 FakeLlmTurn::Stream(events) => {
363 return play_stream(request, events, delta_tx).await;
364 }
365 }
366 }
367 }
368}
369
370async fn play_stream(
371 request: &LlmRequestPayload,
372 events: Vec<FakeLlmEvent>,
373 delta_tx: Option<DeltaSender>,
374) -> Result<LlmResult, VmError> {
375 let mut text = String::new();
376 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
377 let mut blocks: Vec<serde_json::Value> = Vec::new();
378 let mut stop_reason: Option<FakeStopReason> = None;
379 let mut next_tool_index: usize = 1;
380
381 for event in events {
382 match event {
383 FakeLlmEvent::Token(chunk) => {
384 if let Some(tx) = delta_tx.as_ref() {
385 let _ = tx.send(chunk.clone());
386 }
387 text.push_str(&chunk);
388 }
389 FakeLlmEvent::ToolCallDelta {
390 id,
391 name,
392 arguments,
393 } => {
394 let id = if id.is_empty() {
395 let auto = format!("fake_call_{}", next_tool_index);
396 next_tool_index += 1;
397 auto
398 } else {
399 id
400 };
401 tool_calls.push(serde_json::json!({
402 "id": id,
403 "type": "tool_call",
404 "name": name,
405 "arguments": arguments,
406 }));
407 blocks.push(serde_json::json!({
408 "type": "tool_call",
409 "id": id,
410 "name": name,
411 "arguments": arguments,
412 "visibility": "internal",
413 }));
414 }
415 FakeLlmEvent::Stall(duration) => {
416 if !duration.is_zero() {
417 tokio::time::sleep(duration).await;
418 }
419 }
420 FakeLlmEvent::Done(reason) => {
421 stop_reason = Some(reason);
422 break;
423 }
424 }
425 }
426
427 if !text.is_empty() {
428 let text_block = serde_json::json!({
431 "type": "output_text",
432 "text": text,
433 "visibility": "public",
434 });
435 blocks.insert(0, text_block);
436 }
437
438 let stop_reason = stop_reason.unwrap_or(if tool_calls.is_empty() {
439 FakeStopReason::EndTurn
440 } else {
441 FakeStopReason::ToolUse
442 });
443
444 Ok(LlmResult {
445 text,
446 tool_calls,
447 input_tokens: count_input_tokens(&request.messages),
448 output_tokens: 0,
449 cache_read_tokens: 0,
450 cache_write_tokens: 0,
451 model: request.model.clone(),
452 provider: "fake".to_string(),
453 thinking: None,
454 thinking_summary: None,
455 stop_reason: Some(stop_reason.as_str().to_string()),
456 blocks,
457 logprobs: Vec::new(),
458 telemetry: ProviderTelemetry::default(),
459 })
460}
461
462fn count_input_tokens(messages: &[serde_json::Value]) -> i64 {
466 fn collect(value: &serde_json::Value, out: &mut String) {
467 match value {
468 serde_json::Value::String(text) => {
469 out.push_str(text);
470 out.push('\n');
471 }
472 serde_json::Value::Array(items) => {
473 for item in items {
474 collect(item, out);
475 }
476 }
477 serde_json::Value::Object(map) => {
478 for value in map.values() {
479 collect(value, out);
480 }
481 }
482 _ => {}
483 }
484 }
485 let mut buf = String::new();
486 for message in messages {
487 collect(message, &mut buf);
488 }
489 buf.len() as i64
490}
491
// Unit tests for the fake LLM provider. Each test installs a script on the
// current thread, drives `FakeLlmProvider::chat_impl` on a single-threaded
// tokio runtime, and checks the result and/or the captured calls.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::api::ThinkingConfig;
    use crate::llm::api::{LlmRequestPayload, LlmRouteFallback, OutputFormat};

    /// Builds a minimal streaming request aimed at the `"fake"` provider with
    /// a single user message (`"hello"`).
    fn fake_request() -> LlmRequestPayload {
        LlmRequestPayload {
            provider: "fake".to_string(),
            model: "fake-model".to_string(),
            api_key: String::new(),
            fallback_chain: Vec::new(),
            route_fallbacks: Vec::<LlmRouteFallback>::new(),
            messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
            system: None,
            max_tokens: 64,
            temperature: None,
            top_p: None,
            top_k: None,
            logprobs: false,
            top_logprobs: None,
            stop: None,
            seed: None,
            frequency_penalty: None,
            presence_penalty: None,
            output_format: OutputFormat::Text,
            response_format: None,
            json_schema: None,
            thinking: ThinkingConfig::Disabled,
            anthropic_beta_features: Vec::new(),
            vision: false,
            native_tools: None,
            tool_choice: None,
            cache: false,
            timeout: None,
            stream: true,
            provider_overrides: None,
            prefill: None,
            session_id: None,
        }
    }

    /// Builds a current-thread tokio runtime with an unpaused clock, so the
    /// thread-local script is visible to the code under test.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(false)
            .build()
            .expect("runtime")
    }

    /// Tokens are forwarded to the delta channel in script order and
    /// concatenated into the final text and a single `output_text` block.
    #[test]
    fn streaming_turn_emits_deltas_in_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("hello ".into()),
            FakeLlmEvent::Token("world".into()),
            FakeLlmEvent::Done(FakeStopReason::EndTurn),
        ]));

        runtime.block_on(async {
            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), Some(tx))
                .await
                .expect("fake call should succeed");

            // The call has completed, so every delta is already queued;
            // drain the channel without waiting.
            let mut deltas = Vec::new();
            while let Ok(delta) = rx.try_recv() {
                deltas.push(delta);
            }

            assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
            assert_eq!(result.text, "hello world");
            assert_eq!(result.provider, "fake");
            assert_eq!(result.stop_reason.as_deref(), Some("end_turn"));
            assert_eq!(result.blocks.len(), 1);
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            assert_eq!(result.blocks[0]["text"].as_str(), Some("hello world"));
        });
        // Exactly one request was captured; checked while the guard is
        // still alive, because dropping it clears the captured calls.
        assert_eq!(fake_llm_captured_calls().len(), 1);
    }

    /// A tool-call event with an empty id gets the generated id
    /// `fake_call_1` and appears in both `tool_calls` and `blocks`.
    #[test]
    fn tool_call_deltas_become_tool_calls_and_blocks() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("calling tool".into()),
            FakeLlmEvent::ToolCallDelta {
                id: String::new(),
                name: "search".into(),
                arguments: serde_json::json!({"q": "harn"}),
            },
            FakeLlmEvent::Done(FakeStopReason::ToolUse),
        ]));

        runtime.block_on(async {
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("fake call should succeed");

            assert_eq!(result.tool_calls.len(), 1);
            assert_eq!(result.tool_calls[0]["name"].as_str(), Some("search"));
            assert_eq!(result.tool_calls[0]["id"].as_str(), Some("fake_call_1"));
            assert_eq!(
                result.tool_calls[0]["arguments"]["q"].as_str(),
                Some("harn")
            );
            assert_eq!(result.stop_reason.as_deref(), Some("tool_use"));
            // The aggregated text block comes first, then the tool call.
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            assert_eq!(result.blocks[1]["type"].as_str(), Some("tool_call"));
            assert_eq!(result.blocks[1]["name"].as_str(), Some("search"));
        });
    }

    /// An error turn surfaces as `VmError::CategorizedError` with the
    /// scripted category and message.
    #[test]
    fn error_turn_returns_categorized_error() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::erroring(
            ErrorCategory::RateLimit,
            "throttled",
        ));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            match err {
                VmError::CategorizedError { message, category } => {
                    assert_eq!(category, ErrorCategory::RateLimit);
                    assert!(
                        message.contains("throttled"),
                        "error message should pass through: {message}"
                    );
                }
                other => panic!("expected CategorizedError, got {other:?}"),
            }
        });
    }

    /// A retry-after hint of 2500 ms is rendered into the error message as
    /// `retry-after: 2.5` (seconds).
    #[test]
    fn error_turn_embeds_retry_after_hint() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::Error(
            FakeLlmError::new(ErrorCategory::RateLimit, "throttled").with_retry_after_ms(2_500),
        )));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            let VmError::CategorizedError { message, .. } = err else {
                panic!("expected CategorizedError");
            };
            assert!(
                message.contains("retry-after: 2.5"),
                "retry-after hint should be present in synthetic message: {message}"
            );
        });
    }

    /// A `Stalled` turn parks the call until the (paused) tokio clock is
    /// advanced, after which the following stream turn is served.
    #[test]
    fn stalled_turn_advances_under_paused_clock() {
        // This test needs a paused clock, so it builds its own runtime
        // instead of using `current_thread_runtime`.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(true)
            .build()
            .expect("paused runtime");
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::Stalled(Duration::from_secs(60)))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("done".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let request = fake_request();
            let chat = FakeLlmProvider.chat_impl(&request, None);
            tokio::pin!(chat);

            // Poll exactly once: the 60 s stall has not elapsed yet, so the
            // future must still be pending.
            let polled = futures::poll!(&mut chat);
            assert!(
                matches!(polled, std::task::Poll::Pending),
                "fake provider should be parked on the stall"
            );

            // Move the clock past the stall; the call can then complete.
            tokio::time::advance(Duration::from_secs(60)).await;
            let result = chat.await.expect("after advance, fake call resolves");
            assert_eq!(result.text, "done");
        });
    }

    /// Successive calls consume the scripted turns in the order they were
    /// pushed, and every call is captured.
    #[test]
    fn multiple_turns_consumed_in_fifo_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("first".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ]))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("second".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let first = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("first call");
            let second = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("second call");
            assert_eq!(first.text, "first");
            assert_eq!(second.text, "second");
        });

        let calls = fake_llm_captured_calls();
        assert_eq!(calls.len(), 2);
        assert!(calls.iter().all(|c| c.provider == "fake"));
    }

    /// Calling the provider with no script installed yields an error whose
    /// text mentions "no script installed".
    #[test]
    #[should_panic(expected = "no script installed")]
    fn calling_without_script_panics_with_explanatory_error() {
        let runtime = current_thread_runtime();
        runtime
            // The error is converted to a string so that `unwrap` panics
            // with a message `should_panic` can match against.
            .block_on(async {
                FakeLlmProvider
                    .chat_impl(&fake_request(), None)
                    .await
                    .map_err(|e| e.to_string())
            })
            .unwrap();
    }

    /// Dropping the guard while a scripted turn is still unconsumed panics.
    #[test]
    #[should_panic(expected = "unconsumed turn")]
    fn drop_guard_asserts_on_unused_turns() {
        let guard =
            install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::stream(vec![
                FakeLlmEvent::Done(FakeStopReason::EndTurn),
            ])));
        // No call is made, so the explicit drop triggers the assertion.
        drop(guard);
    }
}