1use std::collections::HashMap;
2use std::fs::OpenOptions;
3use std::io::Write;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Instant;
7
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpListener;
10use tokio::task::JoinHandle;
11
12use ai_agents_core::{
13 ChatMessage, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMProvider, LLMResponse, Tool,
14 ToolResult,
15};
16use ai_agents_llm::providers::{ProviderType, UnifiedLLMProvider};
17use ai_agents_llm::{FinishReason, LLMRegistry};
18use ai_agents_runtime::spec::AgentSpec;
19use ai_agents_tools::{ToolRegistry, create_builtin_registry};
20use async_trait::async_trait;
21use futures::Stream;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25use sha2::{Digest, Sha256};
26
27use crate::evidence::{ToolExecutionRecord, ToolExecutionSource};
28use crate::{EvalError, Result};
29
/// Top-level fixture configuration: context values, tool mocks, LLM fixtures
/// and an optional mock HTTP server.
///
/// Every field carries `#[serde(default)]`, so all of them may be omitted in
/// the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FixturesConfig {
    /// Inline context. `resolve_fixture_context` requires a JSON object and
    /// merges it after `context_file`, so inline keys overwrite file keys.
    #[serde(default)]
    pub context: Option<Value>,
    /// Path to a JSON file holding a context object; relative paths are
    /// resolved against the base directory given to `resolve_fixture_context`.
    #[serde(default)]
    pub context_file: Option<PathBuf>,
    /// Mock definitions keyed by tool id; `build_tool_registry` registers a
    /// mock for each key and skips the builtin tool with the same id.
    #[serde(default)]
    pub tools: HashMap<String, ToolMockConfig>,
    /// LLM fixture settings (mode, cassette, canned responses).
    #[serde(default)]
    pub llm: LlmFixtureConfig,
    /// Optional mock HTTP server settings.
    #[serde(default)]
    pub mock_server: Option<MockServerConfig>,
}
49
/// Canned result returned by a mocked tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolMockConfig {
    /// Value reported as `ToolResult::success`; `true` when omitted.
    #[serde(default = "default_true")]
    pub success: bool,
    /// Output payload; `Value::Null` when omitted. Strings are returned
    /// verbatim by the mock tool, other values as serialized JSON.
    #[serde(default)]
    pub output: Value,
}

impl Default for ToolMockConfig {
    /// A successful mock with a `null` output, matching the serde defaults.
    fn default() -> Self {
        Self {
            success: true,
            output: Value::Null,
        }
    }
}
69
/// Controls how LLM providers are built for an evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LlmFixtureConfig {
    /// Provider mode; `LlmFixtureMode::Real` when omitted.
    #[serde(default)]
    pub mode: LlmFixtureMode,
    /// Path to a JSON-lines cassette file. In record mode
    /// `build_llm_registry` falls back to `llm_cassette.jsonl` in the base
    /// directory when this is not set.
    #[serde(default)]
    pub cassette: Option<PathBuf>,
    /// Canned response texts used for aliases that have no entry in
    /// `responses_by_alias`.
    #[serde(default)]
    pub responses: Vec<String>,
    /// Canned response texts per LLM alias.
    #[serde(default)]
    pub responses_by_alias: HashMap<String, Vec<String>>,
    /// Error message per LLM alias; in mock mode the provider for that alias
    /// fails every call with this message.
    #[serde(default)]
    pub errors_by_alias: HashMap<String, String>,
}
89
/// How `build_llm_registry` constructs the provider for each alias.
/// Serialized in `snake_case` (`real`, `mock`, `replay`, `record`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum LlmFixtureMode {
    /// Use the real provider described by the agent spec (default).
    #[default]
    Real,
    /// Serve canned responses in sequence.
    Mock,
    /// Serve cassette records matched by request hash, falling back to the
    /// canned sequence.
    Replay,
    /// Call the real provider and append each completed response to the
    /// cassette file.
    Record,
}
100
/// Settings for the local mock HTTP server.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MockServerConfig {
    /// The server is started only when this is `true`.
    #[serde(default)]
    pub enabled: bool,
    /// Port to bind on 127.0.0.1; when omitted, port 0 is passed to the
    /// listener and the bound address is read back afterwards.
    #[serde(default)]
    pub port: Option<u16>,
    /// Raw route definitions; each is deserialized into a `MockRoute` when
    /// the server starts.
    #[serde(default)]
    pub routes: Vec<Value>,
}
114
/// One canned HTTP response served by the mock server.
#[derive(Debug, Clone, Deserialize)]
struct MockRoute {
    /// HTTP method; compared to the request method ignoring ASCII case.
    method: String,
    /// Request path this route answers.
    path: String,
    /// Response status code; 200 when omitted.
    #[serde(default = "default_status")]
    status: u16,
    /// Extra response headers written after the built-in ones.
    #[serde(default)]
    headers: HashMap<String, String>,
    /// Response body: a JSON string is sent as-is, any other value is sent
    /// as serialized JSON.
    #[serde(default)]
    body: Value,
}
132
133pub struct MockServerHandle {
135 base_url: String,
137 task: JoinHandle<()>,
139}
140
141impl MockServerHandle {
142 pub fn context(&self) -> HashMap<String, Value> {
143 let mut context = HashMap::new();
144 context.insert(
145 "mock_server".to_string(),
146 serde_json::json!({"base_url": self.base_url}),
147 );
148 context
149 }
150}
151
152impl Drop for MockServerHandle {
153 fn drop(&mut self) {
154 self.task.abort();
155 }
156}
157
158#[derive(Clone, Default)]
160pub struct RecordingToolLog {
161 inner: Arc<Mutex<Vec<ToolExecutionRecord>>>,
163}
164
165impl RecordingToolLog {
166 pub fn new() -> Self {
167 Self::default()
168 }
169
170 pub fn len(&self) -> usize {
171 self.inner.lock().len()
172 }
173
174 pub fn push(&self, record: ToolExecutionRecord) {
175 self.inner.lock().push(record);
176 }
177
178 pub fn records_since(&self, index: usize) -> Vec<ToolExecutionRecord> {
179 self.inner.lock().iter().skip(index).cloned().collect()
180 }
181}
182
183pub fn resolve_fixture_context(
184 config: &FixturesConfig,
185 base_dir: &Path,
186) -> Result<HashMap<String, Value>> {
187 let mut result = HashMap::new();
188 if let Some(path) = &config.context_file {
189 let resolved = resolve_path(base_dir, path);
190 let content = std::fs::read_to_string(&resolved).map_err(|error| {
191 EvalError::Config(format!(
192 "failed to read context_file '{}': {}",
193 resolved.display(),
194 error
195 ))
196 })?;
197 let value: Value = serde_json::from_str(&content).map_err(|error| {
198 EvalError::Config(format!(
199 "failed to parse context_file '{}': {}",
200 resolved.display(),
201 error
202 ))
203 })?;
204 merge_object_into_map(&mut result, value)?;
205 }
206 if let Some(value) = &config.context {
207 merge_object_into_map(&mut result, value.clone())?;
208 }
209 Ok(result)
210}
211
212fn merge_object_into_map(target: &mut HashMap<String, Value>, value: Value) -> Result<()> {
213 let Value::Object(map) = value else {
214 return Err(EvalError::Config(
215 "fixture context must be a JSON object".into(),
216 ));
217 };
218 for (key, value) in map {
219 target.insert(key, value);
220 }
221 Ok(())
222}
223
/// Returns `path` unchanged when it is absolute, otherwise `base_dir/path`.
fn resolve_path(base_dir: &Path, path: &Path) -> PathBuf {
    if path.is_relative() {
        return base_dir.join(path);
    }
    path.to_path_buf()
}
231
232pub async fn start_mock_server(
233 config: Option<&MockServerConfig>,
234) -> Result<Option<MockServerHandle>> {
235 let Some(config) = config else {
236 return Ok(None);
237 };
238 if !config.enabled {
239 return Ok(None);
240 }
241 let routes = config
242 .routes
243 .iter()
244 .cloned()
245 .map(serde_json::from_value::<MockRoute>)
246 .collect::<std::result::Result<Vec<_>, _>>()?;
247 let port = config.port.unwrap_or(0);
248 let listener = TcpListener::bind(("127.0.0.1", port))
249 .await
250 .map_err(|error| EvalError::Runtime(format!("failed to start mock server: {}", error)))?;
251 let addr = listener.local_addr().map_err(|error| {
252 EvalError::Runtime(format!("failed to read mock server addr: {}", error))
253 })?;
254 let base_url = format!("http://{}", addr);
255 let task = tokio::spawn(async move {
256 loop {
257 let Ok((stream, _)) = listener.accept().await else {
258 break;
259 };
260 let routes = routes.clone();
261 tokio::spawn(async move {
262 let _ = handle_mock_connection(stream, routes).await;
263 });
264 }
265 });
266 Ok(Some(MockServerHandle { base_url, task }))
267}
268
269async fn handle_mock_connection(
270 mut stream: tokio::net::TcpStream,
271 routes: Vec<MockRoute>,
272) -> std::io::Result<()> {
273 let mut buffer = vec![0_u8; 8192];
274 let read = stream.read(&mut buffer).await?;
275 let request = String::from_utf8_lossy(&buffer[..read]);
276 let first_line = request.lines().next().unwrap_or_default();
277 let mut parts = first_line.split_whitespace();
278 let method = parts.next().unwrap_or_default();
279 let path = parts.next().unwrap_or_default();
280 let route = routes
281 .iter()
282 .find(|route| route.method.eq_ignore_ascii_case(method) && route.path == path);
283 let (status, headers, body) = if let Some(route) = route {
284 (
285 route.status,
286 route.headers.clone(),
287 mock_body_to_string(&route.body),
288 )
289 } else {
290 (
291 404,
292 HashMap::new(),
293 serde_json::json!({"error":"not found"}).to_string(),
294 )
295 };
296 let reason = match status {
297 200 => "OK",
298 201 => "Created",
299 204 => "No Content",
300 400 => "Bad Request",
301 401 => "Unauthorized",
302 403 => "Forbidden",
303 404 => "Not Found",
304 500 => "Internal Server Error",
305 _ => "OK",
306 };
307 let mut response = format!(
308 "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n",
309 status,
310 reason,
311 body.len()
312 );
313 for (key, value) in headers {
314 response.push_str(&format!("{}: {}\r\n", key, value));
315 }
316 response.push_str("\r\n");
317 response.push_str(&body);
318 stream.write_all(response.as_bytes()).await?;
319 Ok(())
320}
321
322fn mock_body_to_string(body: &Value) -> String {
323 if let Some(text) = body.as_str() {
324 text.to_string()
325 } else {
326 serde_json::to_string(body).unwrap_or_else(|_| "null".to_string())
327 }
328}
329
330pub fn build_tool_registry(
331 fixtures: &FixturesConfig,
332 log: RecordingToolLog,
333) -> Result<ToolRegistry> {
334 let builtin = create_builtin_registry();
335 let mut registry = ToolRegistry::new();
336
337 for (id, mock) in &fixtures.tools {
338 registry
339 .register(Arc::new(RecordingTool::new(
340 Arc::new(MockTool::new(id.clone(), mock.clone())),
341 log.clone(),
342 ToolExecutionSource::Mock,
343 )))
344 .map_err(|error| EvalError::Config(error.to_string()))?;
345 }
346
347 for id in builtin.list_ids() {
348 if fixtures.tools.contains_key(&id) {
349 continue;
350 }
351 if let Some(tool) = builtin.get(&id) {
352 registry
353 .register(Arc::new(RecordingTool::new(
354 tool,
355 log.clone(),
356 ToolExecutionSource::Llm,
357 )))
358 .map_err(|error| EvalError::Config(error.to_string()))?;
359 }
360 }
361
362 Ok(registry)
363}
364
365struct MockTool {
367 id: String,
369 config: ToolMockConfig,
371}
372
373impl MockTool {
374 fn new(id: String, config: ToolMockConfig) -> Self {
375 Self { id, config }
376 }
377}
378
379#[async_trait]
380impl Tool for MockTool {
381 fn id(&self) -> &str {
382 &self.id
383 }
384
385 fn name(&self) -> &str {
386 &self.id
387 }
388
389 fn description(&self) -> &str {
390 "Evaluation mock tool"
391 }
392
393 fn input_schema(&self) -> Value {
394 serde_json::json!({"type": "object"})
395 }
396
397 async fn execute(&self, _args: Value) -> ToolResult {
398 let output = if self.config.output.is_string() {
399 self.config.output.as_str().unwrap_or_default().to_string()
400 } else {
401 serde_json::to_string(&self.config.output).unwrap_or_else(|_| "null".to_string())
402 };
403 ToolResult {
404 success: self.config.success,
405 output,
406 metadata: None,
407 }
408 }
409}
410
411struct RecordingTool {
413 inner: Arc<dyn Tool>,
415 log: RecordingToolLog,
417 source: ToolExecutionSource,
419}
420
421impl RecordingTool {
422 fn new(inner: Arc<dyn Tool>, log: RecordingToolLog, source: ToolExecutionSource) -> Self {
423 Self { inner, log, source }
424 }
425}
426
427#[async_trait]
428impl Tool for RecordingTool {
429 fn id(&self) -> &str {
430 self.inner.id()
431 }
432
433 fn name(&self) -> &str {
434 self.inner.name()
435 }
436
437 fn description(&self) -> &str {
438 self.inner.description()
439 }
440
441 fn input_schema(&self) -> Value {
442 self.inner.input_schema()
443 }
444
445 async fn execute(&self, args: Value) -> ToolResult {
446 let started_at = chrono::Utc::now();
447 let start = Instant::now();
448 let result = self.inner.execute(args.clone()).await;
449 let duration_ms = start.elapsed().as_millis() as u64;
450 let output =
451 serde_json::from_str(&result.output).unwrap_or(Value::String(result.output.clone()));
452 self.log.push(ToolExecutionRecord {
453 call_id: uuid::Uuid::new_v4().to_string(),
454 tool_id: self.inner.id().to_string(),
455 requested_name: self.inner.name().to_string(),
456 source: self.source.clone(),
457 state: None,
458 actor_id: None,
459 arguments_original: args.clone(),
460 arguments_executed: args,
461 success: result.success,
462 output: result.success.then_some(output),
463 error: (!result.success).then_some(result.output.clone()),
464 metadata: result
465 .metadata
466 .clone()
467 .map(|m| serde_json::to_value(m).unwrap_or(Value::Null)),
468 started_at,
469 duration_ms,
470 observability_span_id: None,
471 });
472 result
473 }
474}
475
476pub fn build_llm_registry(
477 spec: &AgentSpec,
478 fixtures: &LlmFixtureConfig,
479 base_dir: &Path,
480) -> Result<(LLMRegistry, Option<Arc<dyn LLMProvider>>)> {
481 let mut registry = LLMRegistry::new();
482 let aliases = if spec.llms.is_empty() {
483 vec![(
484 "default".to_string(),
485 spec.llm.as_config().cloned().unwrap_or_default(),
486 )]
487 } else {
488 spec.llms
489 .iter()
490 .map(|(alias, config)| (alias.clone(), config.clone()))
491 .collect()
492 };
493
494 let cassette_records = load_cassette_records(fixtures, base_dir)?;
495 let mut judge_provider = None;
496
497 for (alias, config) in aliases {
498 let fixture_responses = load_fixture_responses_for_alias(fixtures, base_dir, &alias)?;
499 let fixture_error = fixtures.errors_by_alias.get(&alias).cloned();
500 let provider = match fixtures.mode {
501 LlmFixtureMode::Mock => Arc::new(SequenceLLMProvider::new(
502 alias.clone(),
503 fixture_responses.clone(),
504 fixture_error,
505 )) as Arc<dyn LLMProvider>,
506 LlmFixtureMode::Replay => Arc::new(ReplayLLMProvider::new(
507 alias.clone(),
508 config.model.clone(),
509 cassette_records.clone(),
510 fixture_responses.clone(),
511 )) as Arc<dyn LLMProvider>,
512 LlmFixtureMode::Real => build_real_provider(&config)?,
513 LlmFixtureMode::Record => {
514 let inner = build_real_provider(&config)?;
515 let path = fixtures
516 .cassette
517 .as_ref()
518 .map(|p| resolve_path(base_dir, p))
519 .unwrap_or_else(|| base_dir.join("llm_cassette.jsonl"));
520 Arc::new(RecordingLLMProvider::new(
521 inner,
522 alias.clone(),
523 config.model.clone(),
524 path,
525 )) as Arc<dyn LLMProvider>
526 }
527 };
528 if judge_provider.is_none() {
529 judge_provider = Some(provider.clone());
530 }
531 registry.register(alias, provider);
532 }
533
534 let default_alias = spec.llm.get_default_alias();
535 registry.set_default(default_alias);
536 if let Some(router) = spec.llm.get_router_alias() {
537 registry.set_router(router);
538 }
539
540 Ok((registry, judge_provider))
541}
542
543fn build_real_provider(
544 config: &ai_agents_runtime::spec::LLMConfig,
545) -> Result<Arc<dyn LLMProvider>> {
546 use std::str::FromStr;
547 let provider_type = ProviderType::from_str(&config.provider)
548 .map_err(|error| EvalError::Config(error.to_string()))?;
549 let core_config = ai_agents_core::LLMConfig {
550 temperature: Some(config.temperature),
551 max_tokens: Some(config.max_tokens),
552 top_p: config.top_p,
553 top_k: None,
554 frequency_penalty: None,
555 presence_penalty: None,
556 stop_sequences: None,
557 timeout_seconds: config.timeout_seconds,
558 reasoning: config.reasoning,
559 reasoning_effort: config.reasoning_effort.clone(),
560 reasoning_budget_tokens: config.reasoning_budget_tokens,
561 extra: config.extra.clone(),
562 };
563 let base_url = config.base_url.clone().or_else(|| {
564 config
565 .extra
566 .get("base_url")
567 .and_then(Value::as_str)
568 .map(str::to_string)
569 });
570 let api_key = config
571 .api_key_env
572 .as_ref()
573 .and_then(|env| std::env::var(env).ok());
574 let mut provider = UnifiedLLMProvider::from_spec_config(
575 provider_type,
576 &config.model,
577 api_key,
578 base_url,
579 core_config,
580 )
581 .map_err(|error| EvalError::Runtime(error.to_string()))?;
582 if let Some(value) = config.function_calling {
583 provider = provider.with_feature_override(LLMFeature::FunctionCalling, value);
584 }
585 if let Some(value) = config.vision {
586 provider = provider.with_feature_override(LLMFeature::Vision, value);
587 }
588 if let Some(value) = config.json_mode {
589 provider = provider.with_feature_override(LLMFeature::JsonMode, value);
590 }
591 Ok(Arc::new(provider))
592}
593
594fn load_fixture_responses_for_alias(
595 config: &LlmFixtureConfig,
596 base_dir: &Path,
597 alias: &str,
598) -> Result<Vec<LLMResponse>> {
599 let configured = config
600 .responses_by_alias
601 .get(alias)
602 .unwrap_or(&config.responses);
603 let mut responses = Vec::new();
604 for content in configured {
605 responses.push(LLMResponse::new(content.clone(), FinishReason::Stop));
606 }
607 for record in load_cassette_records(config, base_dir)? {
608 if record.alias == alias {
609 responses.push(record.response);
610 }
611 }
612 if responses.is_empty() {
613 responses.push(LLMResponse::new("Mock response", FinishReason::Stop));
614 }
615 Ok(responses)
616}
617
618fn load_cassette_records(
619 config: &LlmFixtureConfig,
620 base_dir: &Path,
621) -> Result<Vec<CassetteRecord>> {
622 let mut records = Vec::new();
623 if let Some(path) = &config.cassette {
624 let resolved = resolve_path(base_dir, path);
625 if resolved.exists() {
626 let content = std::fs::read_to_string(&resolved)?;
627 for line in content.lines().filter(|line| !line.trim().is_empty()) {
628 records.push(serde_json::from_str(line)?);
629 }
630 }
631 }
632 Ok(records)
633}
634
/// One recorded LLM exchange; stored as a single JSON line in the cassette.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CassetteRecord {
    /// LLM alias the call was made through.
    alias: String,
    /// Model name from the alias config; replay also accepts an empty value.
    model: String,
    /// Hash of the request as computed by `hash_request`.
    request_hash: String,
    /// Hash scheme identifier; the recorder writes `"sha256-v1"`. May be
    /// absent in a stored record.
    #[serde(default)]
    request_hash_version: Option<String>,
    /// The response that was returned for the request.
    response: LLMResponse,
}
650
/// Provider that serves a fixed list of responses in order, or fails every
/// call when an error message is configured.
struct SequenceLLMProvider {
    /// Canned responses, served front to back.
    responses: Arc<Mutex<Vec<LLMResponse>>>,
    /// Position of the next response to serve.
    index: Arc<Mutex<usize>>,
    /// When set, every call fails with this message instead.
    error: Option<String>,
}
660
/// Provider that answers from cassette records matched by request hash and
/// falls back to a canned sequence when nothing matches.
struct ReplayLLMProvider {
    /// Alias this provider is registered under; records must match it.
    alias: String,
    /// Model name; a record matches when its model equals this or is empty.
    model: String,
    /// All loaded cassette records.
    records: Arc<Vec<CassetteRecord>>,
    /// Fallback sequence used when no record matches.
    responses: SequenceLLMProvider,
}
672
673impl SequenceLLMProvider {
674 fn new(_name: String, responses: Vec<LLMResponse>, error: Option<String>) -> Self {
675 Self {
676 responses: Arc::new(Mutex::new(responses)),
677 index: Arc::new(Mutex::new(0)),
678 error,
679 }
680 }
681
682 fn next_response(&self) -> LLMResponse {
683 let responses = self.responses.lock();
684 let mut index = self.index.lock();
685 let response = responses
686 .get(*index)
687 .cloned()
688 .or_else(|| responses.last().cloned())
689 .unwrap_or_else(|| LLMResponse::new("Mock response", FinishReason::Stop));
690 if *index + 1 < responses.len() {
691 *index += 1;
692 }
693 response
694 }
695}
696
697impl ReplayLLMProvider {
698 fn new(
699 alias: String,
700 model: String,
701 records: Vec<CassetteRecord>,
702 fallback: Vec<LLMResponse>,
703 ) -> Self {
704 Self {
705 alias,
706 model,
707 records: Arc::new(records),
708 responses: SequenceLLMProvider::new("replay-fallback".to_string(), fallback, None),
709 }
710 }
711
712 fn response_for(&self, messages: &[ChatMessage], config: Option<&LLMConfig>) -> LLMResponse {
713 let request_hash = hash_request(messages, config);
714 if let Some(record) = self.records.iter().find(|record| {
715 record.alias == self.alias
716 && record.request_hash == request_hash
717 && (record.model == self.model || record.model.is_empty())
718 }) {
719 return record.response.clone();
720 }
721 self.responses.next_response()
722 }
723}
724
725#[async_trait]
726impl LLMProvider for ReplayLLMProvider {
727 async fn complete(
728 &self,
729 messages: &[ChatMessage],
730 config: Option<&LLMConfig>,
731 ) -> std::result::Result<LLMResponse, LLMError> {
732 Ok(self.response_for(messages, config))
733 }
734
735 async fn complete_stream(
736 &self,
737 messages: &[ChatMessage],
738 config: Option<&LLMConfig>,
739 ) -> std::result::Result<
740 Box<dyn Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
741 LLMError,
742 > {
743 let response = self.response_for(messages, config);
744 Ok(Box::new(futures::stream::iter(chunks_from_response(
745 response,
746 ))))
747 }
748
749 fn provider_name(&self) -> &str {
750 "eval-replay"
751 }
752
753 fn supports(&self, feature: LLMFeature) -> bool {
754 matches!(feature, LLMFeature::Streaming | LLMFeature::SystemMessages)
755 }
756}
757
758#[async_trait]
759impl LLMProvider for SequenceLLMProvider {
760 async fn complete(
761 &self,
762 _messages: &[ChatMessage],
763 _config: Option<&LLMConfig>,
764 ) -> std::result::Result<LLMResponse, LLMError> {
765 if let Some(error) = &self.error {
766 return Err(LLMError::API {
767 message: error.clone(),
768 status: None,
769 });
770 }
771 Ok(self.next_response())
772 }
773
774 async fn complete_stream(
775 &self,
776 _messages: &[ChatMessage],
777 _config: Option<&LLMConfig>,
778 ) -> std::result::Result<
779 Box<dyn Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
780 LLMError,
781 > {
782 if let Some(error) = &self.error {
783 return Err(LLMError::API {
784 message: error.clone(),
785 status: None,
786 });
787 }
788 let response = self.next_response();
789 Ok(Box::new(futures::stream::iter(chunks_from_response(
790 response,
791 ))))
792 }
793
794 fn provider_name(&self) -> &str {
795 "eval-sequence"
796 }
797
798 fn supports(&self, feature: LLMFeature) -> bool {
799 matches!(feature, LLMFeature::Streaming | LLMFeature::SystemMessages)
800 }
801}
802
803fn chunks_from_response(response: LLMResponse) -> Vec<std::result::Result<LLMChunk, LLMError>> {
804 let deltas = split_stream_content(&response.content);
805 if deltas.is_empty() {
806 return vec![Ok(LLMChunk::final_chunk(
807 "",
808 response.finish_reason,
809 response.usage,
810 ))];
811 }
812 let last_index = deltas.len() - 1;
813 deltas
814 .into_iter()
815 .enumerate()
816 .map(|(index, delta)| {
817 if index == last_index {
818 Ok(LLMChunk::final_chunk(
819 delta,
820 response.finish_reason.clone(),
821 response.usage.clone(),
822 ))
823 } else {
824 Ok(LLMChunk::new(delta, false))
825 }
826 })
827 .collect()
828}
829
/// Splits `content` into stream deltas whose concatenation is exactly
/// `content`.
///
/// * Empty content yields no deltas.
/// * Content with more than one word is cut in front of each whitespace run
///   that follows a word, so every delta after the first starts with the
///   original whitespace (`"a b"` becomes `["a", " b"]`). Newlines, tabs and
///   repeated spaces are kept verbatim instead of being collapsed to a single
///   space.
/// * Otherwise the content is cut into pieces of at most 4 characters (at
///   least two pieces when it has more than one character).
fn split_stream_content(content: &str) -> Vec<String> {
    if content.is_empty() {
        return Vec::new();
    }

    if content.split_whitespace().nth(1).is_some() {
        let mut deltas = Vec::new();
        let mut start = 0;
        // Starts as `true` so leading whitespace stays attached to the first
        // word instead of producing an empty first delta.
        let mut previous_was_whitespace = true;
        for (offset, ch) in content.char_indices() {
            let is_whitespace = ch.is_whitespace();
            if is_whitespace && !previous_was_whitespace {
                deltas.push(content[start..offset].to_string());
                start = offset;
            }
            previous_was_whitespace = is_whitespace;
        }
        deltas.push(content[start..].to_string());
        return deltas;
    }

    let chars: Vec<char> = content.chars().collect();
    if chars.len() <= 1 {
        return vec![content.to_string()];
    }
    let chunk_size = 4.min(chars.len().div_ceil(2));
    chars
        .chunks(chunk_size)
        .map(|chunk| chunk.iter().collect::<String>())
        .collect()
}
858
/// Provider wrapper that forwards to a real provider and appends each
/// completed (non-streaming) response to a cassette file.
struct RecordingLLMProvider {
    /// Provider that performs the actual calls.
    inner: Arc<dyn LLMProvider>,
    /// Alias written into every record.
    alias: String,
    /// Model name written into every record.
    model: String,
    /// Cassette file that records are appended to.
    path: PathBuf,
}
870
871impl RecordingLLMProvider {
872 fn new(inner: Arc<dyn LLMProvider>, alias: String, model: String, path: PathBuf) -> Self {
873 Self {
874 inner,
875 alias,
876 model,
877 path,
878 }
879 }
880}
881
882#[async_trait]
883impl LLMProvider for RecordingLLMProvider {
884 async fn complete(
885 &self,
886 messages: &[ChatMessage],
887 config: Option<&LLMConfig>,
888 ) -> std::result::Result<LLMResponse, LLMError> {
889 let response = self.inner.complete(messages, config).await?;
890 let record = CassetteRecord {
891 alias: self.alias.clone(),
892 model: self.model.clone(),
893 request_hash: hash_request(messages, config),
894 request_hash_version: Some("sha256-v1".to_string()),
895 response: response.clone(),
896 };
897 if let Some(parent) = self.path.parent() {
898 let _ = std::fs::create_dir_all(parent);
899 }
900 if let Ok(mut file) = OpenOptions::new()
901 .create(true)
902 .append(true)
903 .open(&self.path)
904 {
905 let _ = writeln!(
906 file,
907 "{}",
908 serde_json::to_string(&record).unwrap_or_default()
909 );
910 }
911 Ok(response)
912 }
913
914 async fn complete_stream(
915 &self,
916 messages: &[ChatMessage],
917 config: Option<&LLMConfig>,
918 ) -> std::result::Result<
919 Box<dyn Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
920 LLMError,
921 > {
922 self.inner.complete_stream(messages, config).await
923 }
924
925 fn provider_name(&self) -> &str {
926 self.inner.provider_name()
927 }
928
929 fn supports(&self, feature: LLMFeature) -> bool {
930 self.inner.supports(feature)
931 }
932}
933
934fn hash_request(messages: &[ChatMessage], config: Option<&LLMConfig>) -> String {
935 let canonical_messages: Vec<Value> = messages
936 .iter()
937 .map(|message| {
938 json!({
939 "role": format!("{:?}", message.role),
940 "content": message.content,
941 "name": message.name,
942 })
943 })
944 .collect();
945 let canonical = json!({
946 "version": "sha256-v1",
947 "messages": canonical_messages,
948 "config": config,
949 });
950 let encoded = serde_json::to_vec(&canonical).unwrap_or_default();
951 let digest = Sha256::digest(encoded);
952 format!("sha256-v1:{:x}", digest)
953}
954
/// Serde default for `MockRoute::status`: HTTP 200.
fn default_status() -> u16 {
    const HTTP_OK: u16 = 200;
    HTTP_OK
}

/// Serde default for `ToolMockConfig::success`.
fn default_true() -> bool {
    true
}
962
#[cfg(test)]
mod tests {
    use super::*;

    /// Inline context must be merged on top of the context file: keys present
    /// in both take the inline value, other keys from either side are kept.
    #[test]
    fn context_file_and_inline_context_merge() {
        // Unique temp directory so parallel test runs do not collide.
        let dir = std::env::temp_dir().join(format!(
            "ai_agents_eval_fixture_test_{}",
            uuid::Uuid::new_v4()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        let context_path = dir.join("context.json");
        std::fs::write(
            &context_path,
            r#"{"user":{"tier":"basic"},"channel":"file"}"#,
        )
        .unwrap();
        let config = FixturesConfig {
            context: Some(serde_json::json!({"channel":"inline","feature":true})),
            // Relative on purpose: it must be resolved against `dir`.
            context_file: Some(PathBuf::from("context.json")),
            ..Default::default()
        };
        let context = resolve_fixture_context(&config, &dir).unwrap();
        assert_eq!(context.get("channel"), Some(&serde_json::json!("inline")));
        assert_eq!(context.get("feature"), Some(&serde_json::json!(true)));
        assert!(context.get("user").is_some());
        // Best-effort cleanup.
        let _ = std::fs::remove_dir_all(dir);
    }

    /// A multi-word response is streamed as several chunks that concatenate
    /// back to the original text, with a finish reason on the final chunk.
    #[test]
    fn mock_streaming_splits_response_into_multiple_chunks() {
        let response = LLMResponse::new(
            "Streaming hello from the mocked provider.".to_string(),
            FinishReason::Stop,
        );
        let chunks = chunks_from_response(response);
        assert!(chunks.len() > 1);
        let mut reconstructed = String::new();
        for (index, chunk) in chunks.into_iter().enumerate() {
            let chunk = chunk.unwrap();
            if index == 0 {
                assert!(!chunk.is_final);
            }
            reconstructed.push_str(&chunk.delta);
            if chunk.is_final {
                assert!(chunk.finish_reason.is_some());
            }
        }
        assert_eq!(reconstructed, "Streaming hello from the mocked provider.");
    }

    /// Even a single word is split into more than one piece.
    #[test]
    fn mock_streaming_splits_single_word_response() {
        let chunks = split_stream_content("Hello");
        assert!(chunks.len() > 1);
        assert_eq!(chunks.join(""), "Hello");
    }

    /// End-to-end check: start the server on an OS-assigned port, send a raw
    /// HTTP request and verify status line and body of the configured route.
    #[tokio::test]
    async fn mock_server_serves_configured_route() {
        let config = MockServerConfig {
            enabled: true,
            port: None,
            routes: vec![serde_json::json!({
                "method":"GET",
                "path":"/ok",
                "status":200,
                "body":{"ok":true}
            })],
        };
        let server = start_mock_server(Some(&config)).await.unwrap().unwrap();
        let context = server.context();
        // Strip the scheme to get a `host:port` string for `TcpStream`.
        let base_url = context
            .get("mock_server")
            .and_then(|value| value.get("base_url"))
            .and_then(Value::as_str)
            .unwrap()
            .trim_start_matches("http://")
            .to_string();
        let mut stream = tokio::net::TcpStream::connect(base_url).await.unwrap();
        stream
            .write_all(b"GET /ok HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .await
            .unwrap();
        let mut response = String::new();
        // Reads until the server closes the connection after responding.
        stream.read_to_string(&mut response).await.unwrap();
        assert!(response.contains("200 OK"));
        assert!(response.contains("{\"ok\":true}"));
    }
}