Skip to main content

ai_agents_eval/
fixtures.rs

1use std::collections::HashMap;
2use std::fs::OpenOptions;
3use std::io::Write;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Instant;
7
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpListener;
10use tokio::task::JoinHandle;
11
12use ai_agents_core::{
13    ChatMessage, LLMChunk, LLMConfig, LLMError, LLMFeature, LLMProvider, LLMResponse, Tool,
14    ToolResult,
15};
16use ai_agents_llm::providers::{ProviderType, UnifiedLLMProvider};
17use ai_agents_llm::{FinishReason, LLMRegistry};
18use ai_agents_runtime::spec::AgentSpec;
19use ai_agents_tools::{ToolRegistry, create_builtin_registry};
20use async_trait::async_trait;
21use futures::Stream;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24use serde_json::{Value, json};
25use sha2::{Digest, Sha256};
26
27use crate::evidence::{ToolExecutionRecord, ToolExecutionSource};
28use crate::{EvalError, Result};
29
/// Fixture configuration used to replace external dependencies during eval.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FixturesConfig {
    /// Inline context object; its top-level keys are merged after (and so override)
    /// the entries loaded from `context_file`.
    #[serde(default)]
    pub context: Option<Value>,
    /// JSON context file resolved relative to the suite file; must contain a JSON object.
    #[serde(default)]
    pub context_file: Option<PathBuf>,
    /// Mock tool definitions keyed by tool ID; a mock replaces any builtin tool with the same ID.
    #[serde(default)]
    pub tools: HashMap<String, ToolMockConfig>,
    /// LLM fixture settings (mode, cassette, scripted responses and errors) applied to the
    /// agent's LLM aliases.
    #[serde(default)]
    pub llm: LlmFixtureConfig,
    /// Optional local HTTP mock server configuration.
    #[serde(default)]
    pub mock_server: Option<MockServerConfig>,
}
49
50/// Static output configuration for an eval mock tool.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ToolMockConfig {
53    /// Whether the operation succeeded.
54    #[serde(default = "default_true")]
55    pub success: bool,
56    /// Directory where output artifacts are written.
57    #[serde(default)]
58    pub output: Value,
59}
60
61impl Default for ToolMockConfig {
62    fn default() -> Self {
63        Self {
64            success: true,
65            output: Value::Null,
66        }
67    }
68}
69
/// LLM fixture mode and data used by the eval runner.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LlmFixtureConfig {
    /// Strategy used to build the provider for every configured alias (defaults to `real`).
    #[serde(default)]
    pub mode: LlmFixtureMode,
    /// Optional cassette JSONL file (relative paths resolve against the suite directory).
    /// Replay mode matches requests against its records; record mode appends to it.
    /// Records for an alias are also appended after that alias's scripted responses.
    #[serde(default)]
    pub cassette: Option<PathBuf>,
    /// Ordered text responses used by mock mode and fallback replay for aliases that have
    /// no entry in `responses_by_alias`. Once exhausted, the last response repeats.
    #[serde(default)]
    pub responses: Vec<String>,
    /// Per-LLM alias ordered responses for deterministic multi-branch evals; an entry here
    /// replaces `responses` for that alias.
    #[serde(default)]
    pub responses_by_alias: HashMap<String, Vec<String>>,
    /// Per-LLM alias error messages for deterministic failure-path evals. In mock mode, every
    /// call to that alias fails with `LLMError::API` carrying this message.
    #[serde(default)]
    pub errors_by_alias: HashMap<String, String>,
}

/// LLM fixture strategy used while building eval providers.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum LlmFixtureMode {
    /// Call the live provider configured in the agent spec (the default).
    #[default]
    Real,
    /// Return scripted fixture responses in order, without network calls.
    Mock,
    /// Return cassette responses matched by alias, request hash and model, falling back to
    /// the scripted responses when nothing matches.
    Replay,
    /// Call the live provider and append its responses to the cassette file.
    Record,
}
100
/// Local HTTP mock server configuration for eval scenarios.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MockServerConfig {
    /// The server is started only when this is `true`.
    #[serde(default)]
    pub enabled: bool,
    /// Optional fixed port on 127.0.0.1. Zero or none requests a dynamic port.
    #[serde(default)]
    pub port: Option<u16>,
    /// Route definitions served by the mock server. Each entry must deserialize as a route
    /// with `method` and `path`, plus optional `status`, `headers` and `body`.
    #[serde(default)]
    pub routes: Vec<Value>,
}

/// One route served by the lightweight eval mock server.
#[derive(Debug, Clone, Deserialize)]
struct MockRoute {
    /// HTTP method, compared case-insensitively with the request line's method.
    method: String,
    /// Request path, compared with the target on the request line.
    path: String,
    /// HTTP status code returned by this route; defaults to 200.
    #[serde(default = "default_status")]
    status: u16,
    /// Extra response headers returned by this route.
    #[serde(default)]
    headers: HashMap<String, String>,
    /// Response body. A JSON string is sent verbatim; any other JSON value is serialized.
    #[serde(default)]
    body: Value,
}
132
133/// Running mock server handle that stops the server on drop.
134pub struct MockServerHandle {
135    /// Base URL injected into eval context.
136    base_url: String,
137    /// Background accept loop for the mock server.
138    task: JoinHandle<()>,
139}
140
141impl MockServerHandle {
142    pub fn context(&self) -> HashMap<String, Value> {
143        let mut context = HashMap::new();
144        context.insert(
145            "mock_server".to_string(),
146            serde_json::json!({"base_url": self.base_url}),
147        );
148        context
149    }
150}
151
152impl Drop for MockServerHandle {
153    fn drop(&mut self) {
154        self.task.abort();
155    }
156}
157
158/// Shared in-memory log of tool executions for one attempt.
159#[derive(Clone, Default)]
160pub struct RecordingToolLog {
161    /// Wrapped implementation or shared storage.
162    inner: Arc<Mutex<Vec<ToolExecutionRecord>>>,
163}
164
165impl RecordingToolLog {
166    pub fn new() -> Self {
167        Self::default()
168    }
169
170    pub fn len(&self) -> usize {
171        self.inner.lock().len()
172    }
173
174    pub fn push(&self, record: ToolExecutionRecord) {
175        self.inner.lock().push(record);
176    }
177
178    pub fn records_since(&self, index: usize) -> Vec<ToolExecutionRecord> {
179        self.inner.lock().iter().skip(index).cloned().collect()
180    }
181}
182
183pub fn resolve_fixture_context(
184    config: &FixturesConfig,
185    base_dir: &Path,
186) -> Result<HashMap<String, Value>> {
187    let mut result = HashMap::new();
188    if let Some(path) = &config.context_file {
189        let resolved = resolve_path(base_dir, path);
190        let content = std::fs::read_to_string(&resolved).map_err(|error| {
191            EvalError::Config(format!(
192                "failed to read context_file '{}': {}",
193                resolved.display(),
194                error
195            ))
196        })?;
197        let value: Value = serde_json::from_str(&content).map_err(|error| {
198            EvalError::Config(format!(
199                "failed to parse context_file '{}': {}",
200                resolved.display(),
201                error
202            ))
203        })?;
204        merge_object_into_map(&mut result, value)?;
205    }
206    if let Some(value) = &config.context {
207        merge_object_into_map(&mut result, value.clone())?;
208    }
209    Ok(result)
210}
211
212fn merge_object_into_map(target: &mut HashMap<String, Value>, value: Value) -> Result<()> {
213    let Value::Object(map) = value else {
214        return Err(EvalError::Config(
215            "fixture context must be a JSON object".into(),
216        ));
217    };
218    for (key, value) in map {
219        target.insert(key, value);
220    }
221    Ok(())
222}
223
/// Resolves `path` against `base_dir` unless it is already absolute.
fn resolve_path(base_dir: &Path, path: &Path) -> PathBuf {
    if !path.is_absolute() {
        return base_dir.join(path);
    }
    path.to_owned()
}
231
232pub async fn start_mock_server(
233    config: Option<&MockServerConfig>,
234) -> Result<Option<MockServerHandle>> {
235    let Some(config) = config else {
236        return Ok(None);
237    };
238    if !config.enabled {
239        return Ok(None);
240    }
241    let routes = config
242        .routes
243        .iter()
244        .cloned()
245        .map(serde_json::from_value::<MockRoute>)
246        .collect::<std::result::Result<Vec<_>, _>>()?;
247    let port = config.port.unwrap_or(0);
248    let listener = TcpListener::bind(("127.0.0.1", port))
249        .await
250        .map_err(|error| EvalError::Runtime(format!("failed to start mock server: {}", error)))?;
251    let addr = listener.local_addr().map_err(|error| {
252        EvalError::Runtime(format!("failed to read mock server addr: {}", error))
253    })?;
254    let base_url = format!("http://{}", addr);
255    let task = tokio::spawn(async move {
256        loop {
257            let Ok((stream, _)) = listener.accept().await else {
258                break;
259            };
260            let routes = routes.clone();
261            tokio::spawn(async move {
262                let _ = handle_mock_connection(stream, routes).await;
263            });
264        }
265    });
266    Ok(Some(MockServerHandle { base_url, task }))
267}
268
269async fn handle_mock_connection(
270    mut stream: tokio::net::TcpStream,
271    routes: Vec<MockRoute>,
272) -> std::io::Result<()> {
273    let mut buffer = vec![0_u8; 8192];
274    let read = stream.read(&mut buffer).await?;
275    let request = String::from_utf8_lossy(&buffer[..read]);
276    let first_line = request.lines().next().unwrap_or_default();
277    let mut parts = first_line.split_whitespace();
278    let method = parts.next().unwrap_or_default();
279    let path = parts.next().unwrap_or_default();
280    let route = routes
281        .iter()
282        .find(|route| route.method.eq_ignore_ascii_case(method) && route.path == path);
283    let (status, headers, body) = if let Some(route) = route {
284        (
285            route.status,
286            route.headers.clone(),
287            mock_body_to_string(&route.body),
288        )
289    } else {
290        (
291            404,
292            HashMap::new(),
293            serde_json::json!({"error":"not found"}).to_string(),
294        )
295    };
296    let reason = match status {
297        200 => "OK",
298        201 => "Created",
299        204 => "No Content",
300        400 => "Bad Request",
301        401 => "Unauthorized",
302        403 => "Forbidden",
303        404 => "Not Found",
304        500 => "Internal Server Error",
305        _ => "OK",
306    };
307    let mut response = format!(
308        "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n",
309        status,
310        reason,
311        body.len()
312    );
313    for (key, value) in headers {
314        response.push_str(&format!("{}: {}\r\n", key, value));
315    }
316    response.push_str("\r\n");
317    response.push_str(&body);
318    stream.write_all(response.as_bytes()).await?;
319    Ok(())
320}
321
322fn mock_body_to_string(body: &Value) -> String {
323    if let Some(text) = body.as_str() {
324        text.to_string()
325    } else {
326        serde_json::to_string(body).unwrap_or_else(|_| "null".to_string())
327    }
328}
329
330pub fn build_tool_registry(
331    fixtures: &FixturesConfig,
332    log: RecordingToolLog,
333) -> Result<ToolRegistry> {
334    let builtin = create_builtin_registry();
335    let mut registry = ToolRegistry::new();
336
337    for (id, mock) in &fixtures.tools {
338        registry
339            .register(Arc::new(RecordingTool::new(
340                Arc::new(MockTool::new(id.clone(), mock.clone())),
341                log.clone(),
342                ToolExecutionSource::Mock,
343            )))
344            .map_err(|error| EvalError::Config(error.to_string()))?;
345    }
346
347    for id in builtin.list_ids() {
348        if fixtures.tools.contains_key(&id) {
349            continue;
350        }
351        if let Some(tool) = builtin.get(&id) {
352            registry
353                .register(Arc::new(RecordingTool::new(
354                    tool,
355                    log.clone(),
356                    ToolExecutionSource::Llm,
357                )))
358                .map_err(|error| EvalError::Config(error.to_string()))?;
359        }
360    }
361
362    Ok(registry)
363}
364
365/// Tool implementation returning a configured eval fixture result.
366struct MockTool {
367    /// Stable identifier for this item.
368    id: String,
369    /// Configuration used by this component.
370    config: ToolMockConfig,
371}
372
373impl MockTool {
374    fn new(id: String, config: ToolMockConfig) -> Self {
375        Self { id, config }
376    }
377}
378
379#[async_trait]
380impl Tool for MockTool {
381    fn id(&self) -> &str {
382        &self.id
383    }
384
385    fn name(&self) -> &str {
386        &self.id
387    }
388
389    fn description(&self) -> &str {
390        "Evaluation mock tool"
391    }
392
393    fn input_schema(&self) -> Value {
394        serde_json::json!({"type": "object"})
395    }
396
397    async fn execute(&self, _args: Value) -> ToolResult {
398        let output = if self.config.output.is_string() {
399            self.config.output.as_str().unwrap_or_default().to_string()
400        } else {
401            serde_json::to_string(&self.config.output).unwrap_or_else(|_| "null".to_string())
402        };
403        ToolResult {
404            success: self.config.success,
405            output,
406            metadata: None,
407        }
408    }
409}
410
411/// Tool wrapper that records calls before returning the inner result.
412struct RecordingTool {
413    /// Wrapped implementation or shared storage.
414    inner: Arc<dyn Tool>,
415    /// Shared log receiving execution records.
416    log: RecordingToolLog,
417    /// Source category assigned to this execution.
418    source: ToolExecutionSource,
419}
420
421impl RecordingTool {
422    fn new(inner: Arc<dyn Tool>, log: RecordingToolLog, source: ToolExecutionSource) -> Self {
423        Self { inner, log, source }
424    }
425}
426
427#[async_trait]
428impl Tool for RecordingTool {
429    fn id(&self) -> &str {
430        self.inner.id()
431    }
432
433    fn name(&self) -> &str {
434        self.inner.name()
435    }
436
437    fn description(&self) -> &str {
438        self.inner.description()
439    }
440
441    fn input_schema(&self) -> Value {
442        self.inner.input_schema()
443    }
444
445    async fn execute(&self, args: Value) -> ToolResult {
446        let started_at = chrono::Utc::now();
447        let start = Instant::now();
448        let result = self.inner.execute(args.clone()).await;
449        let duration_ms = start.elapsed().as_millis() as u64;
450        let output =
451            serde_json::from_str(&result.output).unwrap_or(Value::String(result.output.clone()));
452        self.log.push(ToolExecutionRecord {
453            call_id: uuid::Uuid::new_v4().to_string(),
454            tool_id: self.inner.id().to_string(),
455            requested_name: self.inner.name().to_string(),
456            source: self.source.clone(),
457            state: None,
458            actor_id: None,
459            arguments_original: args.clone(),
460            arguments_executed: args,
461            success: result.success,
462            output: result.success.then_some(output),
463            error: (!result.success).then_some(result.output.clone()),
464            metadata: result
465                .metadata
466                .clone()
467                .map(|m| serde_json::to_value(m).unwrap_or(Value::Null)),
468            started_at,
469            duration_ms,
470            observability_span_id: None,
471        });
472        result
473    }
474}
475
476pub fn build_llm_registry(
477    spec: &AgentSpec,
478    fixtures: &LlmFixtureConfig,
479    base_dir: &Path,
480) -> Result<(LLMRegistry, Option<Arc<dyn LLMProvider>>)> {
481    let mut registry = LLMRegistry::new();
482    let aliases = if spec.llms.is_empty() {
483        vec![(
484            "default".to_string(),
485            spec.llm.as_config().cloned().unwrap_or_default(),
486        )]
487    } else {
488        spec.llms
489            .iter()
490            .map(|(alias, config)| (alias.clone(), config.clone()))
491            .collect()
492    };
493
494    let cassette_records = load_cassette_records(fixtures, base_dir)?;
495    let mut judge_provider = None;
496
497    for (alias, config) in aliases {
498        let fixture_responses = load_fixture_responses_for_alias(fixtures, base_dir, &alias)?;
499        let fixture_error = fixtures.errors_by_alias.get(&alias).cloned();
500        let provider = match fixtures.mode {
501            LlmFixtureMode::Mock => Arc::new(SequenceLLMProvider::new(
502                alias.clone(),
503                fixture_responses.clone(),
504                fixture_error,
505            )) as Arc<dyn LLMProvider>,
506            LlmFixtureMode::Replay => Arc::new(ReplayLLMProvider::new(
507                alias.clone(),
508                config.model.clone(),
509                cassette_records.clone(),
510                fixture_responses.clone(),
511            )) as Arc<dyn LLMProvider>,
512            LlmFixtureMode::Real => build_real_provider(&config)?,
513            LlmFixtureMode::Record => {
514                let inner = build_real_provider(&config)?;
515                let path = fixtures
516                    .cassette
517                    .as_ref()
518                    .map(|p| resolve_path(base_dir, p))
519                    .unwrap_or_else(|| base_dir.join("llm_cassette.jsonl"));
520                Arc::new(RecordingLLMProvider::new(
521                    inner,
522                    alias.clone(),
523                    config.model.clone(),
524                    path,
525                )) as Arc<dyn LLMProvider>
526            }
527        };
528        if judge_provider.is_none() {
529            judge_provider = Some(provider.clone());
530        }
531        registry.register(alias, provider);
532    }
533
534    let default_alias = spec.llm.get_default_alias();
535    registry.set_default(default_alias);
536    if let Some(router) = spec.llm.get_router_alias() {
537        registry.set_router(router);
538    }
539
540    Ok((registry, judge_provider))
541}
542
543fn build_real_provider(
544    config: &ai_agents_runtime::spec::LLMConfig,
545) -> Result<Arc<dyn LLMProvider>> {
546    use std::str::FromStr;
547    let provider_type = ProviderType::from_str(&config.provider)
548        .map_err(|error| EvalError::Config(error.to_string()))?;
549    let core_config = ai_agents_core::LLMConfig {
550        temperature: Some(config.temperature),
551        max_tokens: Some(config.max_tokens),
552        top_p: config.top_p,
553        top_k: None,
554        frequency_penalty: None,
555        presence_penalty: None,
556        stop_sequences: None,
557        timeout_seconds: config.timeout_seconds,
558        reasoning: config.reasoning,
559        reasoning_effort: config.reasoning_effort.clone(),
560        reasoning_budget_tokens: config.reasoning_budget_tokens,
561        extra: config.extra.clone(),
562    };
563    let base_url = config.base_url.clone().or_else(|| {
564        config
565            .extra
566            .get("base_url")
567            .and_then(Value::as_str)
568            .map(str::to_string)
569    });
570    let api_key = config
571        .api_key_env
572        .as_ref()
573        .and_then(|env| std::env::var(env).ok());
574    let mut provider = UnifiedLLMProvider::from_spec_config(
575        provider_type,
576        &config.model,
577        api_key,
578        base_url,
579        core_config,
580    )
581    .map_err(|error| EvalError::Runtime(error.to_string()))?;
582    if let Some(value) = config.function_calling {
583        provider = provider.with_feature_override(LLMFeature::FunctionCalling, value);
584    }
585    if let Some(value) = config.vision {
586        provider = provider.with_feature_override(LLMFeature::Vision, value);
587    }
588    if let Some(value) = config.json_mode {
589        provider = provider.with_feature_override(LLMFeature::JsonMode, value);
590    }
591    Ok(Arc::new(provider))
592}
593
594fn load_fixture_responses_for_alias(
595    config: &LlmFixtureConfig,
596    base_dir: &Path,
597    alias: &str,
598) -> Result<Vec<LLMResponse>> {
599    let configured = config
600        .responses_by_alias
601        .get(alias)
602        .unwrap_or(&config.responses);
603    let mut responses = Vec::new();
604    for content in configured {
605        responses.push(LLMResponse::new(content.clone(), FinishReason::Stop));
606    }
607    for record in load_cassette_records(config, base_dir)? {
608        if record.alias == alias {
609            responses.push(record.response);
610        }
611    }
612    if responses.is_empty() {
613        responses.push(LLMResponse::new("Mock response", FinishReason::Stop));
614    }
615    Ok(responses)
616}
617
618fn load_cassette_records(
619    config: &LlmFixtureConfig,
620    base_dir: &Path,
621) -> Result<Vec<CassetteRecord>> {
622    let mut records = Vec::new();
623    if let Some(path) = &config.cassette {
624        let resolved = resolve_path(base_dir, path);
625        if resolved.exists() {
626            let content = std::fs::read_to_string(&resolved)?;
627            for line in content.lines().filter(|line| !line.trim().is_empty()) {
628                records.push(serde_json::from_str(line)?);
629            }
630        }
631    }
632    Ok(records)
633}
634
/// One line of a cassette JSONL file: a recorded LLM response keyed by alias and request hash.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CassetteRecord {
    /// LLM alias whose provider recorded (and may replay) this response.
    alias: String,
    /// Model configured for the alias when recorded; an empty value matches any model on replay.
    model: String,
    /// `hash_request` digest of the request messages and config.
    request_hash: String,
    /// Hash format version (e.g. `sha256-v1`); `None` when a record omits the field.
    #[serde(default)]
    request_hash_version: Option<String>,
    /// Full provider response returned when this record is replayed.
    response: LLMResponse,
}
650
/// Deterministic LLM provider returning fixture responses in order.
struct SequenceLLMProvider {
    /// Scripted responses served in order; the last one repeats once the list is exhausted.
    responses: Arc<Mutex<Vec<LLMResponse>>>,
    /// Position in `responses` of the next response to serve.
    index: Arc<Mutex<usize>>,
    /// When set, every completion call fails with `LLMError::API` carrying this message.
    error: Option<String>,
}

/// LLM provider replaying cassette records by request hash.
struct ReplayLLMProvider {
    /// Alias a cassette record must carry to be replayed by this provider.
    alias: String,
    /// Configured model; records must match it unless their `model` is empty.
    model: String,
    /// Cassette records available for hash matching.
    records: Arc<Vec<CassetteRecord>>,
    /// Fallback sequence served when no cassette record matches a request.
    responses: SequenceLLMProvider,
}
672
673impl SequenceLLMProvider {
674    fn new(_name: String, responses: Vec<LLMResponse>, error: Option<String>) -> Self {
675        Self {
676            responses: Arc::new(Mutex::new(responses)),
677            index: Arc::new(Mutex::new(0)),
678            error,
679        }
680    }
681
682    fn next_response(&self) -> LLMResponse {
683        let responses = self.responses.lock();
684        let mut index = self.index.lock();
685        let response = responses
686            .get(*index)
687            .cloned()
688            .or_else(|| responses.last().cloned())
689            .unwrap_or_else(|| LLMResponse::new("Mock response", FinishReason::Stop));
690        if *index + 1 < responses.len() {
691            *index += 1;
692        }
693        response
694    }
695}
696
697impl ReplayLLMProvider {
698    fn new(
699        alias: String,
700        model: String,
701        records: Vec<CassetteRecord>,
702        fallback: Vec<LLMResponse>,
703    ) -> Self {
704        Self {
705            alias,
706            model,
707            records: Arc::new(records),
708            responses: SequenceLLMProvider::new("replay-fallback".to_string(), fallback, None),
709        }
710    }
711
712    fn response_for(&self, messages: &[ChatMessage], config: Option<&LLMConfig>) -> LLMResponse {
713        let request_hash = hash_request(messages, config);
714        if let Some(record) = self.records.iter().find(|record| {
715            record.alias == self.alias
716                && record.request_hash == request_hash
717                && (record.model == self.model || record.model.is_empty())
718        }) {
719            return record.response.clone();
720        }
721        self.responses.next_response()
722    }
723}
724
725#[async_trait]
726impl LLMProvider for ReplayLLMProvider {
727    async fn complete(
728        &self,
729        messages: &[ChatMessage],
730        config: Option<&LLMConfig>,
731    ) -> std::result::Result<LLMResponse, LLMError> {
732        Ok(self.response_for(messages, config))
733    }
734
735    async fn complete_stream(
736        &self,
737        messages: &[ChatMessage],
738        config: Option<&LLMConfig>,
739    ) -> std::result::Result<
740        Box<dyn Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
741        LLMError,
742    > {
743        let response = self.response_for(messages, config);
744        Ok(Box::new(futures::stream::iter(chunks_from_response(
745            response,
746        ))))
747    }
748
749    fn provider_name(&self) -> &str {
750        "eval-replay"
751    }
752
753    fn supports(&self, feature: LLMFeature) -> bool {
754        matches!(feature, LLMFeature::Streaming | LLMFeature::SystemMessages)
755    }
756}
757
758#[async_trait]
759impl LLMProvider for SequenceLLMProvider {
760    async fn complete(
761        &self,
762        _messages: &[ChatMessage],
763        _config: Option<&LLMConfig>,
764    ) -> std::result::Result<LLMResponse, LLMError> {
765        if let Some(error) = &self.error {
766            return Err(LLMError::API {
767                message: error.clone(),
768                status: None,
769            });
770        }
771        Ok(self.next_response())
772    }
773
774    async fn complete_stream(
775        &self,
776        _messages: &[ChatMessage],
777        _config: Option<&LLMConfig>,
778    ) -> std::result::Result<
779        Box<dyn Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
780        LLMError,
781    > {
782        if let Some(error) = &self.error {
783            return Err(LLMError::API {
784                message: error.clone(),
785                status: None,
786            });
787        }
788        let response = self.next_response();
789        Ok(Box::new(futures::stream::iter(chunks_from_response(
790            response,
791        ))))
792    }
793
794    fn provider_name(&self) -> &str {
795        "eval-sequence"
796    }
797
798    fn supports(&self, feature: LLMFeature) -> bool {
799        matches!(feature, LLMFeature::Streaming | LLMFeature::SystemMessages)
800    }
801}
802
803fn chunks_from_response(response: LLMResponse) -> Vec<std::result::Result<LLMChunk, LLMError>> {
804    let deltas = split_stream_content(&response.content);
805    if deltas.is_empty() {
806        return vec![Ok(LLMChunk::final_chunk(
807            "",
808            response.finish_reason,
809            response.usage,
810        ))];
811    }
812    let last_index = deltas.len() - 1;
813    deltas
814        .into_iter()
815        .enumerate()
816        .map(|(index, delta)| {
817            if index == last_index {
818                Ok(LLMChunk::final_chunk(
819                    delta,
820                    response.finish_reason.clone(),
821                    response.usage.clone(),
822                ))
823            } else {
824                Ok(LLMChunk::new(delta, false))
825            }
826        })
827        .collect()
828}
829
/// Splits text into stream deltas whose concatenation is exactly `content`.
///
/// Multi-word text is split before each whitespace run that follows a word, so every delta
/// after the first starts with its original separator (spaces, tabs, newlines). Single-word
/// text is split into character chunks of at most 4 characters, so even one word streams in
/// more than one piece. Empty input yields no deltas.
fn split_stream_content(content: &str) -> Vec<String> {
    if content.is_empty() {
        return Vec::new();
    }
    if content.split_whitespace().nth(1).is_some() {
        return split_at_word_boundaries(content);
    }
    let chars: Vec<char> = content.chars().collect();
    if chars.len() <= 1 {
        return vec![content.to_string()];
    }
    let chunk_size = 4_usize.min(chars.len().div_ceil(2));
    chars
        .chunks(chunk_size)
        .map(|chunk| chunk.iter().collect::<String>())
        .collect()
}

/// Cuts `content` at the start of every whitespace run that follows a non-whitespace char.
///
/// Whitespace is attached to the following word instead of being normalised away, so
/// newlines and repeated spaces survive streaming.
fn split_at_word_boundaries(content: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut start = 0;
    // Start as "after whitespace" so leading whitespace stays with the first word.
    let mut previous_is_space = true;
    for (index, ch) in content.char_indices() {
        let is_space = ch.is_whitespace();
        if is_space && !previous_is_space {
            pieces.push(content[start..index].to_string());
            start = index;
        }
        previous_is_space = is_space;
    }
    pieces.push(content[start..].to_string());
    pieces
}
858
859/// LLM provider wrapper appending responses to a cassette file.
860struct RecordingLLMProvider {
861    /// Wrapped implementation or shared storage.
862    inner: Arc<dyn LLMProvider>,
863    /// LLM alias used for this provider or cassette record.
864    alias: String,
865    /// Model or relationship model name.
866    model: String,
867    /// Path used for file lookup, HTTP routing, or dot-path checks.
868    path: PathBuf,
869}
870
871impl RecordingLLMProvider {
872    fn new(inner: Arc<dyn LLMProvider>, alias: String, model: String, path: PathBuf) -> Self {
873        Self {
874            inner,
875            alias,
876            model,
877            path,
878        }
879    }
880}
881
882#[async_trait]
883impl LLMProvider for RecordingLLMProvider {
884    async fn complete(
885        &self,
886        messages: &[ChatMessage],
887        config: Option<&LLMConfig>,
888    ) -> std::result::Result<LLMResponse, LLMError> {
889        let response = self.inner.complete(messages, config).await?;
890        let record = CassetteRecord {
891            alias: self.alias.clone(),
892            model: self.model.clone(),
893            request_hash: hash_request(messages, config),
894            request_hash_version: Some("sha256-v1".to_string()),
895            response: response.clone(),
896        };
897        if let Some(parent) = self.path.parent() {
898            let _ = std::fs::create_dir_all(parent);
899        }
900        if let Ok(mut file) = OpenOptions::new()
901            .create(true)
902            .append(true)
903            .open(&self.path)
904        {
905            let _ = writeln!(
906                file,
907                "{}",
908                serde_json::to_string(&record).unwrap_or_default()
909            );
910        }
911        Ok(response)
912    }
913
914    async fn complete_stream(
915        &self,
916        messages: &[ChatMessage],
917        config: Option<&LLMConfig>,
918    ) -> std::result::Result<
919        Box<dyn Stream<Item = std::result::Result<LLMChunk, LLMError>> + Unpin + Send>,
920        LLMError,
921    > {
922        self.inner.complete_stream(messages, config).await
923    }
924
925    fn provider_name(&self) -> &str {
926        self.inner.provider_name()
927    }
928
929    fn supports(&self, feature: LLMFeature) -> bool {
930        self.inner.supports(feature)
931    }
932}
933
/// Computes the cassette key for one LLM request.
///
/// The key covers each message's role (its `Debug` rendering), content and name, plus the
/// request config (`null` when none is given). These are wrapped in a JSON object tagged
/// `"version": "sha256-v1"`, serialized, and digested with SHA-256. The result has the form
/// `sha256-v1:<hex digest>`.
///
/// Any change to this canonical form changes every hash and stops existing cassettes from
/// matching, so treat it as a persisted format.
///
/// NOTE(review): determinism relies on JSON object key order. With serde_json's default
/// `Map` (sorted keys) this is stable. If the `preserve_order` feature is enabled and the
/// config contains a `HashMap` (e.g. `extra`), key order — and therefore the hash — may vary
/// between processes. Confirm the feature set.
/// NOTE(review): if serialization fails, the digest of empty input is used, so all such
/// requests would share one hash.
fn hash_request(messages: &[ChatMessage], config: Option<&LLMConfig>) -> String {
    let canonical_messages: Vec<Value> = messages
        .iter()
        .map(|message| {
            json!({
                "role": format!("{:?}", message.role),
                "content": message.content,
                "name": message.name,
            })
        })
        .collect();
    let canonical = json!({
        "version": "sha256-v1",
        "messages": canonical_messages,
        "config": config,
    });
    let encoded = serde_json::to_vec(&canonical).unwrap_or_default();
    let digest = Sha256::digest(encoded);
    format!("sha256-v1:{:x}", digest)
}
954
/// Serde default for a mock route's `status`: HTTP 200.
fn default_status() -> u16 {
    const HTTP_OK: u16 = 200;
    HTTP_OK
}

/// Serde default for flags that are enabled unless configured otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
962
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, uniquely named scratch directory under the system temp dir.
    fn scratch_dir() -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "ai_agents_eval_fixture_test_{}",
            uuid::Uuid::new_v4()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn context_file_and_inline_context_merge() {
        let dir = scratch_dir();
        std::fs::write(
            dir.join("context.json"),
            r#"{"user":{"tier":"basic"},"channel":"file"}"#,
        )
        .unwrap();
        let config = FixturesConfig {
            context: Some(json!({"channel":"inline","feature":true})),
            context_file: Some(PathBuf::from("context.json")),
            ..Default::default()
        };

        let context = resolve_fixture_context(&config, &dir).unwrap();

        // Inline values win over file values; keys unique to either source are kept.
        assert_eq!(context.get("channel"), Some(&json!("inline")));
        assert_eq!(context.get("feature"), Some(&json!(true)));
        assert!(context.contains_key("user"));
        let _ = std::fs::remove_dir_all(dir);
    }

    #[test]
    fn mock_streaming_splits_response_into_multiple_chunks() {
        let text = "Streaming hello from the mocked provider.";
        let chunks: Vec<LLMChunk> = chunks_from_response(LLMResponse::new(
            text.to_string(),
            FinishReason::Stop,
        ))
        .into_iter()
        .map(|chunk| chunk.unwrap())
        .collect();

        assert!(chunks.len() > 1);
        assert!(!chunks[0].is_final);
        let mut rebuilt = String::new();
        for chunk in &chunks {
            rebuilt.push_str(&chunk.delta);
            if chunk.is_final {
                assert!(chunk.finish_reason.is_some());
            }
        }
        assert_eq!(rebuilt, text);
    }

    #[test]
    fn mock_streaming_splits_single_word_response() {
        let pieces = split_stream_content("Hello");
        assert!(pieces.len() > 1);
        assert_eq!(pieces.concat(), "Hello");
    }

    #[tokio::test]
    async fn mock_server_serves_configured_route() {
        let config = MockServerConfig {
            enabled: true,
            port: None,
            routes: vec![json!({
                "method":"GET",
                "path":"/ok",
                "status":200,
                "body":{"ok":true}
            })],
        };
        let server = start_mock_server(Some(&config))
            .await
            .unwrap()
            .expect("enabled config should start a server");
        let address = server.context()["mock_server"]["base_url"]
            .as_str()
            .unwrap()
            .trim_start_matches("http://")
            .to_string();

        let mut stream = tokio::net::TcpStream::connect(address).await.unwrap();
        stream
            .write_all(b"GET /ok HTTP/1.1\r\nHost: localhost\r\n\r\n")
            .await
            .unwrap();
        let mut reply = String::new();
        stream.read_to_string(&mut reply).await.unwrap();

        assert!(reply.contains("200 OK"));
        assert!(reply.contains(r#"{"ok":true}"#));
    }
}