1use std::cell::RefCell;
49use std::time::Duration;
50
51use crate::llm::api::{DeltaSender, LlmRequestPayload, LlmResult};
52use crate::llm::provider::{LlmProvider, LlmProviderChat};
53use crate::value::{ErrorCategory, VmError};
54
/// One scripted response for a single call against the fake provider.
#[derive(Clone, Debug)]
pub enum FakeLlmTurn {
    /// Play back a sequence of streaming events, then resolve the call.
    Stream(Vec<FakeLlmEvent>),
    /// Fail the call with a categorized error (optionally carrying a
    /// retry-after hint — see `FakeLlmError`).
    Error(FakeLlmError),
    /// Sleep for the given duration, then fall through to the *next* turn.
    /// Useful with a paused tokio clock to exercise timeout paths.
    Stalled(Duration),
}
69
70impl FakeLlmTurn {
71 pub fn stream(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
73 Self::Stream(events.into_iter().collect())
74 }
75
76 pub fn error(category: ErrorCategory, message: impl Into<String>) -> Self {
78 Self::Error(FakeLlmError {
79 category,
80 message: message.into(),
81 retry_after_ms: None,
82 })
83 }
84}
85
/// A single streaming event inside a `FakeLlmTurn::Stream` script.
#[derive(Clone, Debug)]
pub enum FakeLlmEvent {
    /// A text delta; forwarded on the delta channel and appended to the
    /// final result text.
    Token(String),
    /// A complete tool invocation. An empty `id` is replaced with an
    /// auto-generated `fake_call_<n>` identifier during playback.
    ToolCallDelta {
        id: String,
        name: String,
        arguments: serde_json::Value,
    },
    /// Sleep mid-stream for the given duration (no-op when zero).
    Stall(Duration),
    /// End the stream with an explicit stop reason; any later events in the
    /// same turn are ignored.
    Done(FakeStopReason),
}
106
/// Stop reasons the fake provider can report, mirroring the wire-level
/// strings real providers emit.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FakeStopReason {
    EndTurn,
    ToolUse,
    MaxTokens,
    StopSequence,
    Custom(String),
}

impl FakeStopReason {
    /// The canonical wire string for this stop reason.
    fn as_str(&self) -> &str {
        match self {
            Self::Custom(custom) => custom.as_str(),
            Self::EndTurn => "end_turn",
            Self::ToolUse => "tool_use",
            Self::MaxTokens => "max_tokens",
            Self::StopSequence => "stop_sequence",
        }
    }
}
128
/// The error payload for a `FakeLlmTurn::Error` turn.
#[derive(Clone, Debug)]
pub struct FakeLlmError {
    /// Category surfaced on the resulting categorized error.
    pub category: ErrorCategory,
    /// Human-readable message passed through to the caller.
    pub message: String,
    /// When set, a `retry-after: <secs>` hint line is appended to the
    /// message during conversion (see `fake_error_to_vm_error`).
    pub retry_after_ms: Option<u64>,
}
139
140impl FakeLlmError {
141 pub fn new(category: ErrorCategory, message: impl Into<String>) -> Self {
142 Self {
143 category,
144 message: message.into(),
145 retry_after_ms: None,
146 }
147 }
148
149 pub fn with_retry_after_ms(mut self, ms: u64) -> Self {
150 self.retry_after_ms = Some(ms);
151 self
152 }
153}
154
/// An ordered script of turns for the fake provider, consumed FIFO — one
/// turn per call. Install with `install_fake_llm_script`.
#[derive(Clone, Debug, Default)]
pub struct FakeLlmScript {
    /// Remaining turns, in the order they will be consumed.
    pub turns: Vec<FakeLlmTurn>,
}
165
166impl FakeLlmScript {
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn streaming(events: impl IntoIterator<Item = FakeLlmEvent>) -> Self {
173 Self {
174 turns: vec![FakeLlmTurn::stream(events)],
175 }
176 }
177
178 pub fn erroring(category: ErrorCategory, message: impl Into<String>) -> Self {
180 Self {
181 turns: vec![FakeLlmTurn::error(category, message)],
182 }
183 }
184
185 pub fn push(mut self, turn: FakeLlmTurn) -> Self {
187 self.turns.push(turn);
188 self
189 }
190}
191
/// A snapshot of the request fields captured for each call made against the
/// fake provider; retrieve with `fake_llm_captured_calls`.
#[derive(Clone, Debug)]
pub struct FakeLlmCall {
    pub provider: String,
    pub model: String,
    pub system: Option<String>,
    pub messages: Vec<serde_json::Value>,
    pub native_tools: Option<Vec<serde_json::Value>>,
    pub stream: bool,
}
203
204impl FakeLlmCall {
205 fn from_request(request: &LlmRequestPayload) -> Self {
206 Self {
207 provider: request.provider.clone(),
208 model: request.model.clone(),
209 system: request.system.clone(),
210 messages: request.messages.clone(),
211 native_tools: request.native_tools.clone(),
212 stream: request.stream,
213 }
214 }
215}
216
thread_local! {
    // Remaining scripted turns for the current thread; consumed FIFO by
    // `take_next_turn`.
    static FAKE_LLM_TURNS: RefCell<Vec<FakeLlmTurn>> = const { RefCell::new(Vec::new()) };
    // Requests captured since the last script install; cleared when a new
    // script is installed and again when the guard drops.
    static FAKE_LLM_CALLS: RefCell<Vec<FakeLlmCall>> = const { RefCell::new(Vec::new()) };
}
221
222#[must_use = "FakeLlmGuard asserts on drop; bind it to a `_guard` local"]
230pub fn install_fake_llm_script(script: FakeLlmScript) -> FakeLlmGuard {
231 FAKE_LLM_TURNS.with(|turns| {
232 let mut turns = turns.borrow_mut();
233 assert!(
234 turns.is_empty(),
235 "FakeLlmProvider: a script is already installed; drop the previous guard before installing a new one"
236 );
237 *turns = script.turns;
238 });
239 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
240 FakeLlmGuard { _priv: () }
241}
242
243pub fn fake_llm_captured_calls() -> Vec<FakeLlmCall> {
245 FAKE_LLM_CALLS.with(|calls| calls.borrow().clone())
246}
247
/// RAII guard returned by `install_fake_llm_script`. On drop it clears the
/// thread-local script state and (unless the thread is already panicking)
/// asserts that every scripted turn was consumed.
#[must_use]
pub struct FakeLlmGuard {
    // Private field so the guard can only be built by this module.
    _priv: (),
}
257
258impl Drop for FakeLlmGuard {
259 fn drop(&mut self) {
260 let remaining = FAKE_LLM_TURNS.with(|turns| std::mem::take(&mut *turns.borrow_mut()));
261 FAKE_LLM_CALLS.with(|calls| calls.borrow_mut().clear());
262 if std::thread::panicking() {
265 return;
266 }
267 assert!(
268 remaining.is_empty(),
269 "FakeLlmProvider script had {} unconsumed turn(s); did the code under test make fewer LLM calls than expected?",
270 remaining.len()
271 );
272 }
273}
274
275fn take_next_turn(request: &LlmRequestPayload) -> Result<FakeLlmTurn, VmError> {
278 FAKE_LLM_CALLS.with(|calls| {
279 calls.borrow_mut().push(FakeLlmCall::from_request(request));
280 });
281 FAKE_LLM_TURNS.with(|turns| {
282 let mut turns = turns.borrow_mut();
283 if turns.is_empty() {
284 Err(VmError::Runtime(
285 "FakeLlmProvider: no script installed (or script exhausted) — install_fake_llm_script() must precede llm_call(provider: \"fake\")".to_string()
286 ))
287 } else {
288 Ok(turns.remove(0))
289 }
290 })
291}
292
293fn fake_error_to_vm_error(err: &FakeLlmError) -> VmError {
296 let message = match err.retry_after_ms {
297 Some(ms) => {
298 let secs = (ms as f64 / 1000.0).max(0.0);
299 let sep = if err.message.is_empty() || err.message.ends_with('\n') {
300 ""
301 } else {
302 "\n"
303 };
304 format!("{}{sep}retry-after: {secs}\n", err.message)
305 }
306 None => err.message.clone(),
307 };
308 VmError::CategorizedError {
309 message,
310 category: err.category.clone(),
311 }
312}
313
/// Test-only provider that replays the thread-local script installed via
/// `install_fake_llm_script` instead of contacting a real backend.
pub(crate) struct FakeLlmProvider;
316
impl LlmProvider for FakeLlmProvider {
    // Provider name; matches the string checked by `should_intercept`.
    fn name(&self) -> &str {
        "fake"
    }

    // The fake provider accepts requests without a configured model.
    fn requires_model(&self) -> bool {
        false
    }
}
326
impl LlmProviderChat for FakeLlmProvider {
    // Thin object-safe adapter: box the `chat_impl` future so the trait can
    // be used behind `dyn`.
    fn chat<'a>(
        &'a self,
        request: &'a LlmRequestPayload,
        delta_tx: Option<DeltaSender>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<LlmResult, VmError>> + 'a>> {
        Box::pin(self.chat_impl(request, delta_tx))
    }
}
336
337impl FakeLlmProvider {
338 pub(crate) fn should_intercept(provider: &str) -> bool {
340 provider == "fake"
341 }
342
343 pub(crate) async fn chat_impl(
344 &self,
345 request: &LlmRequestPayload,
346 delta_tx: Option<DeltaSender>,
347 ) -> Result<LlmResult, VmError> {
348 loop {
349 let turn = take_next_turn(request)?;
350 match turn {
351 FakeLlmTurn::Stalled(duration) => {
352 if !duration.is_zero() {
353 tokio::time::sleep(duration).await;
354 }
355 }
359 FakeLlmTurn::Error(err) => {
360 return Err(fake_error_to_vm_error(&err));
361 }
362 FakeLlmTurn::Stream(events) => {
363 return play_stream(request, events, delta_tx).await;
364 }
365 }
366 }
367 }
368}
369
370async fn play_stream(
371 request: &LlmRequestPayload,
372 events: Vec<FakeLlmEvent>,
373 delta_tx: Option<DeltaSender>,
374) -> Result<LlmResult, VmError> {
375 let mut text = String::new();
376 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
377 let mut blocks: Vec<serde_json::Value> = Vec::new();
378 let mut stop_reason: Option<FakeStopReason> = None;
379 let mut next_tool_index: usize = 1;
380
381 for event in events {
382 match event {
383 FakeLlmEvent::Token(chunk) => {
384 if let Some(tx) = delta_tx.as_ref() {
385 let _ = tx.send(chunk.clone());
386 }
387 text.push_str(&chunk);
388 }
389 FakeLlmEvent::ToolCallDelta {
390 id,
391 name,
392 arguments,
393 } => {
394 let id = if id.is_empty() {
395 let auto = format!("fake_call_{}", next_tool_index);
396 next_tool_index += 1;
397 auto
398 } else {
399 id
400 };
401 tool_calls.push(serde_json::json!({
402 "id": id,
403 "type": "tool_call",
404 "name": name,
405 "arguments": arguments,
406 }));
407 blocks.push(serde_json::json!({
408 "type": "tool_call",
409 "id": id,
410 "name": name,
411 "arguments": arguments,
412 "visibility": "internal",
413 }));
414 }
415 FakeLlmEvent::Stall(duration) => {
416 if !duration.is_zero() {
417 tokio::time::sleep(duration).await;
418 }
419 }
420 FakeLlmEvent::Done(reason) => {
421 stop_reason = Some(reason);
422 break;
423 }
424 }
425 }
426
427 if !text.is_empty() {
428 let text_block = serde_json::json!({
431 "type": "output_text",
432 "text": text,
433 "visibility": "public",
434 });
435 blocks.insert(0, text_block);
436 }
437
438 let stop_reason = stop_reason.unwrap_or(if tool_calls.is_empty() {
439 FakeStopReason::EndTurn
440 } else {
441 FakeStopReason::ToolUse
442 });
443
444 Ok(LlmResult {
445 text,
446 tool_calls,
447 input_tokens: count_input_tokens(&request.messages),
448 output_tokens: 0,
449 cache_read_tokens: 0,
450 cache_write_tokens: 0,
451 model: request.model.clone(),
452 provider: "fake".to_string(),
453 thinking: None,
454 thinking_summary: None,
455 stop_reason: Some(stop_reason.as_str().to_string()),
456 blocks,
457 })
458}
459
460fn count_input_tokens(messages: &[serde_json::Value]) -> i64 {
464 fn collect(value: &serde_json::Value, out: &mut String) {
465 match value {
466 serde_json::Value::String(text) => {
467 out.push_str(text);
468 out.push('\n');
469 }
470 serde_json::Value::Array(items) => {
471 for item in items {
472 collect(item, out);
473 }
474 }
475 serde_json::Value::Object(map) => {
476 for value in map.values() {
477 collect(value, out);
478 }
479 }
480 _ => {}
481 }
482 }
483 let mut buf = String::new();
484 for message in messages {
485 collect(message, &mut buf);
486 }
487 buf.len() as i64
488}
489
#[cfg(test)]
mod tests {
    //! Self-tests for the fake provider. Every test uses a current-thread
    //! tokio runtime because the script and capture log live in
    //! thread-locals: the code under test must run on the installing thread.
    use super::*;
    use crate::llm::api::ThinkingConfig;
    use crate::llm::api::{LlmRequestPayload, LlmRouteFallback, OutputFormat};

    // Minimal streaming request addressed to the fake provider.
    fn fake_request() -> LlmRequestPayload {
        LlmRequestPayload {
            provider: "fake".to_string(),
            model: "fake-model".to_string(),
            api_key: String::new(),
            fallback_chain: Vec::new(),
            route_fallbacks: Vec::<LlmRouteFallback>::new(),
            messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
            system: None,
            max_tokens: 64,
            temperature: None,
            top_p: None,
            top_k: None,
            stop: None,
            seed: None,
            frequency_penalty: None,
            presence_penalty: None,
            output_format: OutputFormat::Text,
            response_format: None,
            json_schema: None,
            thinking: ThinkingConfig::Disabled,
            anthropic_beta_features: Vec::new(),
            vision: false,
            native_tools: None,
            tool_choice: None,
            cache: false,
            timeout: None,
            stream: true,
            provider_overrides: None,
            prefill: None,
            session_id: None,
        }
    }

    // Real-time (non-paused) single-threaded runtime for the common case.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(false)
            .build()
            .expect("runtime")
    }

    // Tokens must arrive on the delta channel in script order and also be
    // aggregated into `text` and a leading output_text block.
    #[test]
    fn streaming_turn_emits_deltas_in_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("hello ".into()),
            FakeLlmEvent::Token("world".into()),
            FakeLlmEvent::Done(FakeStopReason::EndTurn),
        ]));

        runtime.block_on(async {
            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), Some(tx))
                .await
                .expect("fake call should succeed");

            let mut deltas = Vec::new();
            while let Ok(delta) = rx.try_recv() {
                deltas.push(delta);
            }

            assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
            assert_eq!(result.text, "hello world");
            assert_eq!(result.provider, "fake");
            assert_eq!(result.stop_reason.as_deref(), Some("end_turn"));
            assert_eq!(result.blocks.len(), 1);
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            assert_eq!(result.blocks[0]["text"].as_str(), Some("hello world"));
        });
        assert_eq!(fake_llm_captured_calls().len(), 1);
    }

    // An empty tool-call id is auto-filled (`fake_call_1`), and the text
    // block stays ahead of tool-call blocks in `blocks`.
    #[test]
    fn tool_call_deltas_become_tool_calls_and_blocks() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::streaming(vec![
            FakeLlmEvent::Token("calling tool".into()),
            FakeLlmEvent::ToolCallDelta {
                id: String::new(),
                name: "search".into(),
                arguments: serde_json::json!({"q": "harn"}),
            },
            FakeLlmEvent::Done(FakeStopReason::ToolUse),
        ]));

        runtime.block_on(async {
            let result = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("fake call should succeed");

            assert_eq!(result.tool_calls.len(), 1);
            assert_eq!(result.tool_calls[0]["name"].as_str(), Some("search"));
            assert_eq!(result.tool_calls[0]["id"].as_str(), Some("fake_call_1"));
            assert_eq!(
                result.tool_calls[0]["arguments"]["q"].as_str(),
                Some("harn")
            );
            assert_eq!(result.stop_reason.as_deref(), Some("tool_use"));
            assert_eq!(result.blocks[0]["type"].as_str(), Some("output_text"));
            assert_eq!(result.blocks[1]["type"].as_str(), Some("tool_call"));
            assert_eq!(result.blocks[1]["name"].as_str(), Some("search"));
        });
    }

    // An Error turn surfaces as VmError::CategorizedError with the scripted
    // category and message.
    #[test]
    fn error_turn_returns_categorized_error() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::erroring(
            ErrorCategory::RateLimit,
            "throttled",
        ));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            match err {
                VmError::CategorizedError { message, category } => {
                    assert_eq!(category, ErrorCategory::RateLimit);
                    assert!(
                        message.contains("throttled"),
                        "error message should pass through: {message}"
                    );
                }
                other => panic!("expected CategorizedError, got {other:?}"),
            }
        });
    }

    // retry_after_ms = 2500 must appear as "retry-after: 2.5" (seconds) in
    // the synthesized error message.
    #[test]
    fn error_turn_embeds_retry_after_hint() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::Error(
            FakeLlmError::new(ErrorCategory::RateLimit, "throttled").with_retry_after_ms(2_500),
        )));

        runtime.block_on(async {
            let err = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect_err("fake error turn should fail");
            let VmError::CategorizedError { message, .. } = err else {
                panic!("expected CategorizedError");
            };
            assert!(
                message.contains("retry-after: 2.5"),
                "retry-after hint should be present in synthetic message: {message}"
            );
        });
    }

    // With tokio's clock paused, a Stalled turn parks the future (Pending)
    // until time is advanced, after which the next turn resolves the call.
    #[test]
    fn stalled_turn_advances_under_paused_clock() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .start_paused(true)
            .build()
            .expect("paused runtime");
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::Stalled(Duration::from_secs(60)))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("done".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let request = fake_request();
            let chat = FakeLlmProvider.chat_impl(&request, None);
            tokio::pin!(chat);

            let polled = futures::poll!(&mut chat);
            assert!(
                matches!(polled, std::task::Poll::Pending),
                "fake provider should be parked on the stall"
            );

            tokio::time::advance(Duration::from_secs(60)).await;
            let result = chat.await.expect("after advance, fake call resolves");
            assert_eq!(result.text, "done");
        });
    }

    // Turns are consumed front-to-back, one per chat_impl call, and every
    // call is captured.
    #[test]
    fn multiple_turns_consumed_in_fifo_order() {
        let runtime = current_thread_runtime();
        let _guard = install_fake_llm_script(
            FakeLlmScript::default()
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("first".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ]))
                .push(FakeLlmTurn::stream(vec![
                    FakeLlmEvent::Token("second".into()),
                    FakeLlmEvent::Done(FakeStopReason::EndTurn),
                ])),
        );

        runtime.block_on(async {
            let first = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("first call");
            let second = FakeLlmProvider
                .chat_impl(&fake_request(), None)
                .await
                .expect("second call");
            assert_eq!(first.text, "first");
            assert_eq!(second.text, "second");
        });

        let calls = fake_llm_captured_calls();
        assert_eq!(calls.len(), 2);
        assert!(calls.iter().all(|c| c.provider == "fake"));
    }

    // With no installed script, the runtime error explains what to do; the
    // unwrap converts it into the expected panic.
    #[test]
    #[should_panic(expected = "no script installed")]
    fn calling_without_script_panics_with_explanatory_error() {
        let runtime = current_thread_runtime();
        runtime
            .block_on(async {
                FakeLlmProvider
                    .chat_impl(&fake_request(), None)
                    .await
                    .map_err(|e| e.to_string())
            })
            .unwrap();
    }

    // Dropping the guard with unconsumed turns must fail the test loudly.
    #[test]
    #[should_panic(expected = "unconsumed turn")]
    fn drop_guard_asserts_on_unused_turns() {
        let guard =
            install_fake_llm_script(FakeLlmScript::default().push(FakeLlmTurn::stream(vec![
                FakeLlmEvent::Done(FakeStopReason::EndTurn),
            ])));
        drop(guard);
    }
}