chimerai/llm/
mod.rs

1use std::pin::Pin;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::Stream;
6
7use crate::tools::Tool;
8use crate::types::{Decision, Message};
9
10#[async_trait]
11pub trait LLMClient: Send + Sync {
12    async fn complete(
13        &self,
14        messages: &[Message],
15        tools: Vec<&Box<dyn Tool>>,
16        max_tokens: Option<usize>,
17    ) -> Result<Decision>;
18
19    async fn stream_complete(
20        &self,
21        messages: &[Message],
22        tools: Vec<&Box<dyn Tool>>,
23        max_tokens: Option<usize>,
24    ) -> Result<Pin<Box<dyn Stream<Item = Result<Decision>> + Send>>>;
25}
26
/// Test support for [`LLMClient`]: an echoing mock client plus the tests that
/// pin its behavior. Exposed as `pub(crate)` so other test modules in the
/// crate can reuse [`MockLLMClient`].
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    // Brings `.next()` into scope for the boxed stream returned by `stream_complete`.
    use futures::StreamExt;

    /// Stateless [`LLMClient`] that echoes the last user message back.
    #[derive(Debug, Default)]
    pub struct MockLLMClient;

    impl MockLLMClient {
        /// Creates a new mock client.
        pub fn new() -> Self {
            Self
        }
    }

    #[async_trait]
    impl LLMClient for MockLLMClient {
        /// Responds with `"Echo: <content>"` when the last message is a
        /// `Message::User`; otherwise responds with `"No messages provided"`.
        ///
        /// Note that the fallback is also used when the slice is non-empty but
        /// its last element is not a user message. `tools` and `max_tokens`
        /// are ignored.
        async fn complete(
            &self,
            messages: &[Message],
            _tools: Vec<&Box<dyn Tool>>,
            _max_tokens: Option<usize>,
        ) -> Result<Decision> {
            if let Some(Message::User { content }) = messages.last() {
                Ok(Decision::Respond(format!("Echo: {}", content)))
            } else {
                Ok(Decision::Respond("No messages provided".to_string()))
            }
        }

        /// Wraps the result of [`LLMClient::complete`] in a stream that yields
        /// exactly one item.
        async fn stream_complete(
            &self,
            messages: &[Message],
            tools: Vec<&Box<dyn Tool>>,
            max_tokens: Option<usize>,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<Decision>> + Send>>> {
            let response = self.complete(messages, tools, max_tokens).await?;
            Ok(Box::pin(futures::stream::once(async move { Ok(response) })))
        }
    }

    /// Builds the single-user-message conversation shared by the tests below.
    fn hello_conversation() -> Vec<Message> {
        vec![Message::User {
            content: "Hello".to_string(),
        }]
    }

    #[tokio::test]
    async fn test_mock_llm_client() {
        let client = MockLLMClient::new();
        let messages = hello_conversation();

        let response = client.complete(&messages, vec![], Some(100)).await.unwrap();

        match response {
            Decision::Respond(response) => {
                assert_eq!(response, "Echo: Hello");
            }
            _ => panic!("Expected Respond variant"),
        }
    }

    #[tokio::test]
    async fn test_mock_llm_client_without_messages() {
        let client = MockLLMClient::new();

        let response = client.complete(&[], vec![], None).await.unwrap();

        match response {
            Decision::Respond(response) => {
                assert_eq!(response, "No messages provided");
            }
            _ => panic!("Expected Respond variant"),
        }
    }

    #[tokio::test]
    async fn test_mock_llm_client_stream() {
        let client = MockLLMClient::new();
        let messages = hello_conversation();

        let mut stream = client
            .stream_complete(&messages, vec![], Some(100))
            .await
            .unwrap();

        match stream.next().await {
            Some(Ok(Decision::Respond(response))) => {
                assert_eq!(response, "Echo: Hello");
            }
            Some(Ok(_)) => panic!("Expected Respond variant"),
            Some(Err(_)) | None => panic!("Expected a chunk from stream"),
        }

        // The mock emits a single decision, so the stream must now be exhausted.
        assert!(stream.next().await.is_none());
    }
}