// agent_chain_core/language_models/fake.rs
use std::collections::HashMap;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::time::Duration;
10
11use async_trait::async_trait;
12use serde_json::Value;
13
14use super::base::{BaseLanguageModel, LanguageModelConfig, LanguageModelInput};
15use super::llms::{BaseLLM, LLM, LLMConfig, LLMStream};
16use crate::caches::BaseCache;
17use crate::callbacks::{CallbackManagerForLLMRun, Callbacks};
18use crate::error::Result;
19use crate::outputs::{Generation, GenerationChunk, GenerationType, LLMResult};
20
/// A fake LLM for testing that cycles through a fixed list of canned
/// responses, returning the next one on each call and wrapping back to
/// the start after the last.
#[derive(Debug)]
pub struct FakeListLLM {
    // Canned responses, served in round-robin order.
    responses: Vec<String>,
    // Optional artificial delay. NOTE(review): only consumed by
    // `FakeStreamingListLLM::stream_prompt` (per chunk); plain `call`
    // currently ignores it — confirm intent.
    sleep: Option<Duration>,
    // Cursor into `responses`; atomic so `&self` methods can advance it.
    index: AtomicUsize,
    // LLM-level configuration exposed via `llm_config()` / `config()`.
    config: LLMConfig,
}
36
37impl Clone for FakeListLLM {
38 fn clone(&self) -> Self {
39 Self {
40 responses: self.responses.clone(),
41 sleep: self.sleep,
42 index: AtomicUsize::new(self.index.load(Ordering::SeqCst)),
43 config: self.config.clone(),
44 }
45 }
46}
47
48impl FakeListLLM {
49 pub fn new(responses: Vec<String>) -> Self {
51 Self {
52 responses,
53 sleep: None,
54 index: AtomicUsize::new(0),
55 config: LLMConfig::default(),
56 }
57 }
58
59 pub fn with_sleep(mut self, duration: Duration) -> Self {
61 self.sleep = Some(duration);
62 self
63 }
64
65 pub fn with_config(mut self, config: LLMConfig) -> Self {
67 self.config = config;
68 self
69 }
70
71 pub fn current_index(&self) -> usize {
73 self.index.load(Ordering::SeqCst)
74 }
75
76 pub fn reset(&self) {
78 self.index.store(0, Ordering::SeqCst);
79 }
80
81 fn get_next_response(&self) -> String {
83 let i = self.index.load(Ordering::SeqCst);
84 let response = self.responses.get(i).cloned().unwrap_or_default();
85
86 let next_i = if i + 1 < self.responses.len() {
88 i + 1
89 } else {
90 0
91 };
92 self.index.store(next_i, Ordering::SeqCst);
93
94 response
95 }
96}
97
98#[async_trait]
99impl BaseLanguageModel for FakeListLLM {
100 fn llm_type(&self) -> &str {
101 "fake-list"
102 }
103
104 fn model_name(&self) -> &str {
105 "fake-list-llm"
106 }
107
108 fn config(&self) -> &LanguageModelConfig {
109 &self.config.base
110 }
111
112 fn cache(&self) -> Option<&dyn BaseCache> {
113 None
114 }
115
116 fn callbacks(&self) -> Option<&Callbacks> {
117 None
118 }
119
120 async fn generate_prompt(
121 &self,
122 prompts: Vec<LanguageModelInput>,
123 stop: Option<Vec<String>>,
124 _callbacks: Option<Callbacks>,
125 ) -> Result<LLMResult> {
126 let prompt_strings: Vec<String> = prompts.iter().map(|p| p.to_string()).collect();
127 self.generate_prompts(prompt_strings, stop, None).await
128 }
129
130 fn identifying_params(&self) -> HashMap<String, Value> {
131 let mut params = HashMap::new();
132 params.insert("_type".to_string(), Value::String("fake-list".to_string()));
133 params.insert(
134 "responses".to_string(),
135 serde_json::to_value(&self.responses).unwrap_or_default(),
136 );
137 params
138 }
139}
140
141#[async_trait]
142impl BaseLLM for FakeListLLM {
143 fn llm_config(&self) -> &LLMConfig {
144 &self.config
145 }
146
147 async fn generate_prompts(
148 &self,
149 prompts: Vec<String>,
150 _stop: Option<Vec<String>>,
151 _run_manager: Option<&CallbackManagerForLLMRun>,
152 ) -> Result<LLMResult> {
153 let mut generations = Vec::new();
154
155 for _ in prompts {
156 let response = self.get_next_response();
157 let generation = Generation::new(response);
158 generations.push(vec![GenerationType::Generation(generation)]);
159 }
160
161 Ok(LLMResult::new(generations))
162 }
163}
164
165#[async_trait]
166impl LLM for FakeListLLM {
167 async fn call(
168 &self,
169 _prompt: String,
170 _stop: Option<Vec<String>>,
171 _run_manager: Option<&CallbackManagerForLLMRun>,
172 ) -> Result<String> {
173 Ok(self.get_next_response())
174 }
175}
176
/// Error representing a deliberate mid-stream failure at a given chunk
/// index, for exercising stream error handling in tests.
///
/// NOTE(review): `FakeStreamingListLLM::stream_prompt` currently yields
/// `crate::error::Error::Other` with the same message rather than
/// wrapping this type — confirm whether it should emit `FakeListLLMError`.
#[derive(Debug, Clone, thiserror::Error)]
#[error("FakeListLLM error on chunk {0}")]
pub struct FakeListLLMError(pub usize);
181
/// A fake LLM that streams its canned responses one character per chunk.
///
/// Wraps [`FakeListLLM`] for response selection and configuration, and can
/// optionally inject an error at a specific chunk index.
#[derive(Debug)]
pub struct FakeStreamingListLLM {
    // Underlying round-robin fake providing responses, sleep, and config.
    inner: FakeListLLM,
    // When set, streaming yields an error at this 0-based chunk index
    // instead of emitting the chunk.
    error_on_chunk_number: Option<usize>,
}
193
194impl FakeStreamingListLLM {
195 pub fn new(responses: Vec<String>) -> Self {
197 Self {
198 inner: FakeListLLM::new(responses),
199 error_on_chunk_number: None,
200 }
201 }
202
203 pub fn with_sleep(mut self, duration: Duration) -> Self {
205 self.inner = self.inner.with_sleep(duration);
206 self
207 }
208
209 pub fn with_config(mut self, config: LLMConfig) -> Self {
211 self.inner = self.inner.with_config(config);
212 self
213 }
214
215 pub fn with_error_on_chunk(mut self, chunk_number: usize) -> Self {
217 self.error_on_chunk_number = Some(chunk_number);
218 self
219 }
220
221 pub fn current_index(&self) -> usize {
223 self.inner.current_index()
224 }
225
226 pub fn reset(&self) {
228 self.inner.reset();
229 }
230}
231
232impl Clone for FakeStreamingListLLM {
233 fn clone(&self) -> Self {
234 Self {
235 inner: self.inner.clone(),
236 error_on_chunk_number: self.error_on_chunk_number,
237 }
238 }
239}
240
#[async_trait]
impl BaseLanguageModel for FakeStreamingListLLM {
    /// Type tag used for identification/serialization.
    fn llm_type(&self) -> &str {
        "fake-streaming-list"
    }

    fn model_name(&self) -> &str {
        "fake-streaming-list-llm"
    }

    // Delegates to the wrapped FakeListLLM's configuration.
    fn config(&self) -> &LanguageModelConfig {
        self.inner.config()
    }

    // No caching for the fake model.
    fn cache(&self) -> Option<&dyn BaseCache> {
        None
    }

    // No callbacks for the fake model.
    fn callbacks(&self) -> Option<&Callbacks> {
        None
    }

    /// Delegates generation to the wrapped [`FakeListLLM`].
    async fn generate_prompt(
        &self,
        prompts: Vec<LanguageModelInput>,
        stop: Option<Vec<String>>,
        callbacks: Option<Callbacks>,
    ) -> Result<LLMResult> {
        self.inner.generate_prompt(prompts, stop, callbacks).await
    }

    // NOTE(review): reports the inner model's params, whose "_type" is
    // "fake-list" rather than "fake-streaming-list" — confirm intent.
    fn identifying_params(&self) -> HashMap<String, Value> {
        self.inner.identifying_params()
    }
}
276
277#[async_trait]
278impl BaseLLM for FakeStreamingListLLM {
279 fn llm_config(&self) -> &LLMConfig {
280 self.inner.llm_config()
281 }
282
283 async fn generate_prompts(
284 &self,
285 prompts: Vec<String>,
286 stop: Option<Vec<String>>,
287 run_manager: Option<&CallbackManagerForLLMRun>,
288 ) -> Result<LLMResult> {
289 self.inner
290 .generate_prompts(prompts, stop, run_manager)
291 .await
292 }
293
294 async fn stream_prompt(
295 &self,
296 prompt: String,
297 _stop: Option<Vec<String>>,
298 _run_manager: Option<&CallbackManagerForLLMRun>,
299 ) -> Result<LLMStream> {
300 let response = self.inner.call(prompt, None, None).await?;
302 let sleep = self.inner.sleep;
303 let error_on_chunk = self.error_on_chunk_number;
304
305 let stream = async_stream::stream! {
307 for (i, c) in response.chars().enumerate() {
308 if let Some(error_chunk) = error_on_chunk
310 && i == error_chunk
311 {
312 yield Err(crate::error::Error::Other(
313 format!("FakeListLLM error on chunk {}", i)
314 ));
315 return;
316 }
317
318 if let Some(duration) = sleep {
320 tokio::time::sleep(duration).await;
321 }
322
323 yield Ok(GenerationChunk::new(c.to_string()));
324 }
325 };
326
327 Ok(Box::pin(stream))
328 }
329}
330
#[async_trait]
impl LLM for FakeStreamingListLLM {
    /// Non-streaming call: delegates directly to the wrapped [`FakeListLLM`],
    /// returning the full next canned response.
    async fn call(
        &self,
        prompt: String,
        stop: Option<Vec<String>>,
        run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<String> {
        self.inner.call(prompt, stop, run_manager).await
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Responses are served in order and wrap around after the last one.
    #[tokio::test]
    async fn test_fake_list_llm_responses() {
        let llm = FakeListLLM::new(vec![
            "Response 1".to_string(),
            "Response 2".to_string(),
            "Response 3".to_string(),
        ]);

        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 1");

        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 2");

        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 3");

        // Fourth call wraps back to the first response.
        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 1");
    }

    /// An empty response list yields empty strings without panicking.
    #[tokio::test]
    async fn test_fake_list_llm_empty_responses() {
        let llm = FakeListLLM::new(Vec::new());

        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "");
        assert_eq!(llm.current_index(), 0);
    }

    /// `reset` rewinds the cursor to the first response.
    #[tokio::test]
    async fn test_fake_list_llm_reset() {
        let llm = FakeListLLM::new(vec!["Response 1".to_string(), "Response 2".to_string()]);

        let _ = llm.call("prompt".to_string(), None, None).await;
        assert_eq!(llm.current_index(), 1);

        llm.reset();
        assert_eq!(llm.current_index(), 0);

        let result = llm.call("prompt".to_string(), None, None).await.unwrap();
        assert_eq!(result, "Response 1");
    }

    /// Batch generation yields one generation list per prompt.
    #[tokio::test]
    async fn test_fake_list_llm_generate_prompts() {
        let llm = FakeListLLM::new(vec!["Response 1".to_string(), "Response 2".to_string()]);

        let result = llm
            .generate_prompts(
                vec!["prompt1".to_string(), "prompt2".to_string()],
                None,
                None,
            )
            .await
            .unwrap();

        assert_eq!(result.generations.len(), 2);
    }

    /// Streaming emits the response one character per chunk.
    #[tokio::test]
    async fn test_fake_streaming_list_llm() {
        use futures::StreamExt;

        let llm = FakeStreamingListLLM::new(vec!["Hello".to_string()]);

        let mut stream = llm
            .stream_prompt("prompt".to_string(), None, None)
            .await
            .unwrap();

        let mut result = String::new();
        while let Some(chunk) = stream.next().await {
            result.push_str(&chunk.unwrap().text);
        }

        assert_eq!(result, "Hello");
    }

    /// `with_error_on_chunk` fails the stream at the configured chunk index,
    /// after emitting every chunk before it.
    #[tokio::test]
    async fn test_fake_streaming_list_llm_error_on_chunk() {
        use futures::StreamExt;

        let llm =
            FakeStreamingListLLM::new(vec!["Hello".to_string()]).with_error_on_chunk(2);

        let mut stream = llm
            .stream_prompt("prompt".to_string(), None, None)
            .await
            .unwrap();

        let mut received = String::new();
        let mut saw_error = false;
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(c) => received.push_str(&c.text),
                Err(_) => {
                    saw_error = true;
                    break;
                }
            }
        }

        assert!(saw_error);
        // Chunks 0 and 1 ("H", "e") arrive before the injected error.
        assert_eq!(received, "He");
    }

    /// Identifying params expose the type tag and canned responses.
    #[test]
    fn test_fake_list_llm_identifying_params() {
        let llm = FakeListLLM::new(vec!["Response".to_string()]);
        let params = llm.identifying_params();

        assert_eq!(params.get("_type").unwrap(), "fake-list");
        assert!(params.contains_key("responses"));
    }
}
432}