agent_chain_core/language_models/fake.rs

//! Fake LLMs for testing purposes.
//!
//! This module provides fake LLM implementations that can be used
//! for testing without making actual API calls.
//! Mirrors `langchain_core.language_models.fake`.
6
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::time::Duration;
10
11use async_trait::async_trait;
12use serde_json::Value;
13
14use super::base::{BaseLanguageModel, LanguageModelConfig, LanguageModelInput};
15use super::llms::{BaseLLM, LLM, LLMConfig, LLMStream};
16use crate::caches::BaseCache;
17use crate::callbacks::{CallbackManagerForLLMRun, Callbacks};
18use crate::error::Result;
19use crate::outputs::{Generation, GenerationChunk, GenerationType, LLMResult};
20
21/// Fake LLM for testing purposes.
22///
23/// Returns responses from a list in order, cycling back to the start
24/// when the end is reached.
25#[derive(Debug)]
26pub struct FakeListLLM {
27    /// List of responses to return in order.
28    responses: Vec<String>,
29    /// Sleep time in seconds between responses (ignored in base class).
30    sleep: Option<Duration>,
31    /// Current index (internally incremented after every invocation).
32    index: AtomicUsize,
33    /// LLM configuration.
34    config: LLMConfig,
35}
36
37impl Clone for FakeListLLM {
38    fn clone(&self) -> Self {
39        Self {
40            responses: self.responses.clone(),
41            sleep: self.sleep,
42            index: AtomicUsize::new(self.index.load(Ordering::SeqCst)),
43            config: self.config.clone(),
44        }
45    }
46}
47
48impl FakeListLLM {
49    /// Create a new FakeListLLM with the given responses.
50    pub fn new(responses: Vec<String>) -> Self {
51        Self {
52            responses,
53            sleep: None,
54            index: AtomicUsize::new(0),
55            config: LLMConfig::default(),
56        }
57    }
58
59    /// Set the sleep duration between responses.
60    pub fn with_sleep(mut self, duration: Duration) -> Self {
61        self.sleep = Some(duration);
62        self
63    }
64
65    /// Set the configuration.
66    pub fn with_config(mut self, config: LLMConfig) -> Self {
67        self.config = config;
68        self
69    }
70
71    /// Get the current index.
72    pub fn current_index(&self) -> usize {
73        self.index.load(Ordering::SeqCst)
74    }
75
76    /// Reset the index to 0.
77    pub fn reset(&self) {
78        self.index.store(0, Ordering::SeqCst);
79    }
80
81    /// Get the next response and advance the index.
82    fn get_next_response(&self) -> String {
83        let i = self.index.load(Ordering::SeqCst);
84        let response = self.responses.get(i).cloned().unwrap_or_default();
85
86        // Advance index, cycling back to start
87        let next_i = if i + 1 < self.responses.len() {
88            i + 1
89        } else {
90            0
91        };
92        self.index.store(next_i, Ordering::SeqCst);
93
94        response
95    }
96}
97
98#[async_trait]
99impl BaseLanguageModel for FakeListLLM {
100    fn llm_type(&self) -> &str {
101        "fake-list"
102    }
103
104    fn model_name(&self) -> &str {
105        "fake-list-llm"
106    }
107
108    fn config(&self) -> &LanguageModelConfig {
109        &self.config.base
110    }
111
112    fn cache(&self) -> Option<&dyn BaseCache> {
113        None
114    }
115
116    fn callbacks(&self) -> Option<&Callbacks> {
117        None
118    }
119
120    async fn generate_prompt(
121        &self,
122        prompts: Vec<LanguageModelInput>,
123        stop: Option<Vec<String>>,
124        _callbacks: Option<Callbacks>,
125    ) -> Result<LLMResult> {
126        let prompt_strings: Vec<String> = prompts.iter().map(|p| p.to_string()).collect();
127        self.generate_prompts(prompt_strings, stop, None).await
128    }
129
130    fn identifying_params(&self) -> HashMap<String, Value> {
131        let mut params = HashMap::new();
132        params.insert("_type".to_string(), Value::String("fake-list".to_string()));
133        params.insert(
134            "responses".to_string(),
135            serde_json::to_value(&self.responses).unwrap_or_default(),
136        );
137        params
138    }
139}
140
141#[async_trait]
142impl BaseLLM for FakeListLLM {
143    fn llm_config(&self) -> &LLMConfig {
144        &self.config
145    }
146
147    async fn generate_prompts(
148        &self,
149        prompts: Vec<String>,
150        _stop: Option<Vec<String>>,
151        _run_manager: Option<&CallbackManagerForLLMRun>,
152    ) -> Result<LLMResult> {
153        let mut generations = Vec::new();
154
155        for _ in prompts {
156            let response = self.get_next_response();
157            let generation = Generation::new(response);
158            generations.push(vec![GenerationType::Generation(generation)]);
159        }
160
161        Ok(LLMResult::new(generations))
162    }
163}
164
165#[async_trait]
166impl LLM for FakeListLLM {
167    async fn call(
168        &self,
169        _prompt: String,
170        _stop: Option<Vec<String>>,
171        _run_manager: Option<&CallbackManagerForLLMRun>,
172    ) -> Result<String> {
173        Ok(self.get_next_response())
174    }
175}
176
177/// Error raised by FakeStreamingListLLM during streaming.
178#[derive(Debug, Clone, thiserror::Error)]
179#[error("FakeListLLM error on chunk {0}")]
180pub struct FakeListLLMError(pub usize);
181
182/// Fake streaming list LLM for testing purposes.
183///
184/// An LLM that will return responses from a list in order,
185/// with support for streaming character by character.
186#[derive(Debug)]
187pub struct FakeStreamingListLLM {
188    /// Inner FakeListLLM.
189    inner: FakeListLLM,
190    /// If set, will raise an exception on the specified chunk number.
191    error_on_chunk_number: Option<usize>,
192}
193
194impl FakeStreamingListLLM {
195    /// Create a new FakeStreamingListLLM with the given responses.
196    pub fn new(responses: Vec<String>) -> Self {
197        Self {
198            inner: FakeListLLM::new(responses),
199            error_on_chunk_number: None,
200        }
201    }
202
203    /// Set the sleep duration between chunks.
204    pub fn with_sleep(mut self, duration: Duration) -> Self {
205        self.inner = self.inner.with_sleep(duration);
206        self
207    }
208
209    /// Set the configuration.
210    pub fn with_config(mut self, config: LLMConfig) -> Self {
211        self.inner = self.inner.with_config(config);
212        self
213    }
214
215    /// Set the chunk number to error on.
216    pub fn with_error_on_chunk(mut self, chunk_number: usize) -> Self {
217        self.error_on_chunk_number = Some(chunk_number);
218        self
219    }
220
221    /// Get the current index.
222    pub fn current_index(&self) -> usize {
223        self.inner.current_index()
224    }
225
226    /// Reset the index to 0.
227    pub fn reset(&self) {
228        self.inner.reset();
229    }
230}
231
232impl Clone for FakeStreamingListLLM {
233    fn clone(&self) -> Self {
234        Self {
235            inner: self.inner.clone(),
236            error_on_chunk_number: self.error_on_chunk_number,
237        }
238    }
239}
240
241#[async_trait]
242impl BaseLanguageModel for FakeStreamingListLLM {
243    fn llm_type(&self) -> &str {
244        "fake-streaming-list"
245    }
246
247    fn model_name(&self) -> &str {
248        "fake-streaming-list-llm"
249    }
250
251    fn config(&self) -> &LanguageModelConfig {
252        self.inner.config()
253    }
254
255    fn cache(&self) -> Option<&dyn BaseCache> {
256        None
257    }
258
259    fn callbacks(&self) -> Option<&Callbacks> {
260        None
261    }
262
263    async fn generate_prompt(
264        &self,
265        prompts: Vec<LanguageModelInput>,
266        stop: Option<Vec<String>>,
267        callbacks: Option<Callbacks>,
268    ) -> Result<LLMResult> {
269        self.inner.generate_prompt(prompts, stop, callbacks).await
270    }
271
272    fn identifying_params(&self) -> HashMap<String, Value> {
273        self.inner.identifying_params()
274    }
275}
276
277#[async_trait]
278impl BaseLLM for FakeStreamingListLLM {
279    fn llm_config(&self) -> &LLMConfig {
280        self.inner.llm_config()
281    }
282
283    async fn generate_prompts(
284        &self,
285        prompts: Vec<String>,
286        stop: Option<Vec<String>>,
287        run_manager: Option<&CallbackManagerForLLMRun>,
288    ) -> Result<LLMResult> {
289        self.inner
290            .generate_prompts(prompts, stop, run_manager)
291            .await
292    }
293
294    async fn stream_prompt(
295        &self,
296        prompt: String,
297        _stop: Option<Vec<String>>,
298        _run_manager: Option<&CallbackManagerForLLMRun>,
299    ) -> Result<LLMStream> {
300        // Get the response for this prompt
301        let response = self.inner.call(prompt, None, None).await?;
302        let sleep = self.inner.sleep;
303        let error_on_chunk = self.error_on_chunk_number;
304
305        // Create a stream that yields each character
306        let stream = async_stream::stream! {
307            for (i, c) in response.chars().enumerate() {
308                // Check if we should error on this chunk
309                if let Some(error_chunk) = error_on_chunk
310                    && i == error_chunk
311                {
312                    yield Err(crate::error::Error::Other(
313                        format!("FakeListLLM error on chunk {}", i)
314                    ));
315                    return;
316                }
317
318                // Sleep if configured
319                if let Some(duration) = sleep {
320                    tokio::time::sleep(duration).await;
321                }
322
323                yield Ok(GenerationChunk::new(c.to_string()));
324            }
325        };
326
327        Ok(Box::pin(stream))
328    }
329}
330
331#[async_trait]
332impl LLM for FakeStreamingListLLM {
333    async fn call(
334        &self,
335        prompt: String,
336        stop: Option<Vec<String>>,
337        run_manager: Option<&CallbackManagerForLLMRun>,
338    ) -> Result<String> {
339        self.inner.call(prompt, stop, run_manager).await
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_fake_list_llm_responses() {
        let llm = FakeListLLM::new(vec![
            "Response 1".to_string(),
            "Response 2".to_string(),
            "Response 3".to_string(),
        ]);

        // First call
        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 1");

        // Second call
        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 2");

        // Third call
        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 3");

        // Fourth call (cycles back)
        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 1");
    }

    #[tokio::test]
    async fn test_fake_list_llm_reset() {
        let llm = FakeListLLM::new(vec!["Response 1".to_string(), "Response 2".to_string()]);

        // Advance index
        let _ = llm.call("prompt".to_string(), None, None).await;
        assert_eq!(llm.current_index(), 1);

        // Reset
        llm.reset();
        assert_eq!(llm.current_index(), 0);

        // Should get first response again
        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 1");
    }

    #[tokio::test]
    async fn test_fake_list_llm_empty_responses() {
        let llm = FakeListLLM::new(Vec::new());

        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "");
        assert_eq!(llm.current_index(), 0);
    }

    #[tokio::test]
    async fn test_fake_list_llm_clone_snapshots_index() {
        let llm = FakeListLLM::new(vec!["Response 1".to_string(), "Response 2".to_string()]);
        let _ = llm.call("prompt".to_string(), None, None).await;

        let cloned = llm.clone();
        assert_eq!(cloned.current_index(), 1);

        // Advancing the original must not move the clone.
        let _ = llm.call("prompt".to_string(), None, None).await;
        assert_eq!(llm.current_index(), 0);
        assert_eq!(cloned.current_index(), 1);
    }

    #[tokio::test]
    async fn test_fake_list_llm_generate_prompts() {
        let llm = FakeListLLM::new(vec!["Response 1".to_string(), "Response 2".to_string()]);

        let result = llm
            .generate_prompts(
                vec!["prompt1".to_string(), "prompt2".to_string()],
                None,
                None,
            )
            .await
            .unwrap();

        assert_eq!(result.generations.len(), 2);
        // Both responses consumed, so the index has wrapped back to the start.
        assert_eq!(llm.current_index(), 0);
    }

    #[tokio::test]
    async fn test_fake_streaming_list_llm() {
        let llm = FakeStreamingListLLM::new(vec!["Hello".to_string()]);

        let mut stream = llm
            .stream_prompt("prompt".to_string(), None, None)
            .await
            .unwrap();

        let mut result = String::new();
        while let Some(chunk) = stream.next().await {
            result.push_str(&chunk.unwrap().text);
        }

        assert_eq!(result, "Hello");
    }

    #[tokio::test]
    async fn test_fake_streaming_list_llm_advances_index() {
        let llm = FakeStreamingListLLM::new(vec!["ab".to_string(), "cd".to_string()]);

        let _stream = llm
            .stream_prompt("prompt".to_string(), None, None)
            .await
            .unwrap();

        // The response is fetched before the stream is returned.
        assert_eq!(llm.current_index(), 1);
    }

    #[tokio::test]
    async fn test_fake_streaming_list_llm_error_on_chunk() {
        let llm = FakeStreamingListLLM::new(vec!["Hello".to_string()]).with_error_on_chunk(2);

        let mut stream = llm
            .stream_prompt("prompt".to_string(), None, None)
            .await
            .unwrap();

        let mut ok_text = String::new();
        let mut errors = 0;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(c) => {
                    assert_eq!(errors, 0, "no chunks may follow the injected error");
                    ok_text.push_str(&c.text);
                }
                Err(_) => errors += 1,
            }
        }

        assert_eq!(ok_text, "He");
        assert_eq!(errors, 1);
    }

    #[test]
    fn test_fake_list_llm_identifying_params() {
        let llm = FakeListLLM::new(vec!["Response".to_string()]);
        let params = llm.identifying_params();

        assert_eq!(params.get("_type").unwrap(), "fake-list");
        assert!(params.contains_key("responses"));
    }
}