Skip to main content

rig_cat/agent/
mod.rs

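//! Agent: the high-level LLM abstraction.
//!
//! An `Agent` bundles a completion model with an optional system
//! preamble, a toolbox, and optional temperature / max-token settings.
//! `.prompt()` builds a request from the preamble and the user input
//! and returns `Io<Error, String>`; `.prompt_stream()` returns a
//! `Stream<Error, StreamChunk>` instead.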
6
7use comp_cat_rs::effect::io::Io;
8use comp_cat_rs::effect::stream::Stream;
9
10use crate::error::Error;
11use crate::model::{
12    CompletionModel, CompletionRequest, Message, StreamChunk,
13};
14use crate::tool::{Tool, Toolbox};
15
16/// An agent: a configured LLM with preamble, tools, and context.
17///
18/// Built via `AgentBuilder`.  Generic over:
19/// - `M`: the completion model
20/// - `T`: the tool type (use an enum for heterogeneous tools)
21pub struct Agent<M: CompletionModel, T: Tool> {
22    model: M,
23    preamble: Option<String>,
24    tools: Toolbox<T>,
25    temperature: Option<f64>,
26    max_tokens: Option<u32>,
27}
28
29/// Builder for constructing an `Agent`.
30pub struct AgentBuilder<M: CompletionModel, T: Tool> {
31    model: M,
32    preamble: Option<String>,
33    tools: Toolbox<T>,
34    temperature: Option<f64>,
35    max_tokens: Option<u32>,
36}
37
38impl<M: CompletionModel, T: Tool> AgentBuilder<M, T> {
39    /// Start building an agent with the given model.
40    #[must_use]
41    pub fn new(model: M) -> Self {
42        Self {
43            model,
44            preamble: None,
45            tools: Toolbox::new(),
46            temperature: None,
47            max_tokens: None,
48        }
49    }
50
51    /// Set the system preamble.
52    #[must_use]
53    pub fn preamble(self, preamble: impl Into<String>) -> Self {
54        Self { preamble: Some(preamble.into()), ..self }
55    }
56
57    /// Set the tools available to the agent.
58    #[must_use]
59    pub fn tools(self, tools: Toolbox<T>) -> Self {
60        Self { tools, ..self }
61    }
62
63    /// Set the temperature.
64    #[must_use]
65    pub fn temperature(self, t: f64) -> Self {
66        Self { temperature: Some(t), ..self }
67    }
68
69    /// Set the max tokens.
70    #[must_use]
71    pub fn max_tokens(self, n: u32) -> Self {
72        Self { max_tokens: Some(n), ..self }
73    }
74
75    /// Build the agent.
76    #[must_use]
77    pub fn build(self) -> Agent<M, T> {
78        Agent {
79            model: self.model,
80            preamble: self.preamble,
81            tools: self.tools,
82            temperature: self.temperature,
83            max_tokens: self.max_tokens,
84        }
85    }
86}
87
88impl<M: CompletionModel, T: Tool> Agent<M, T> {
89    /// Send a prompt and get a complete response.
90    pub fn prompt(&self, user_input: &str) -> Io<Error, String> {
91        let request = self.build_request(user_input);
92        self.model.complete(request).map(|r| r.content().to_owned())
93    }
94
95    /// Send a prompt and get a streaming response.
96    pub fn prompt_stream(&self, user_input: &str) -> Stream<Error, StreamChunk> {
97        let request = self.build_request(user_input);
98        self.model.stream(request)
99    }
100
101    /// Access the toolbox.
102    #[must_use]
103    pub fn tools(&self) -> &Toolbox<T> { &self.tools }
104
105    fn build_request(&self, user_input: &str) -> CompletionRequest {
106        let messages = self.preamble.iter()
107            .map(|p| Message::system(p.clone()))
108            .chain(std::iter::once(Message::user(user_input.to_owned())))
109            .collect();
110        let request = CompletionRequest::new(messages);
111        let request = match self.temperature {
112            Some(t) => request.with_temperature(t),
113            None => request,
114        };
115        match self.max_tokens {
116            Some(n) => request.with_max_tokens(n),
117            None => request,
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::CompletionResponse;
    use crate::tool::ToolDefinition;

    /// Model stub: `complete` echoes every message's content joined
    /// by `|`; `stream` yields an empty stream.
    struct FakeModel;

    impl CompletionModel for FakeModel {
        fn complete(&self, request: CompletionRequest) -> Io<Error, CompletionResponse> {
            let contents: Vec<_> = request
                .messages()
                .iter()
                .map(|message| message.content().to_owned())
                .collect();
            Io::pure(CompletionResponse::new(contents.join("|"), "fake".into()))
        }

        fn stream(&self, _request: CompletionRequest) -> Stream<Error, StreamChunk> {
            Stream::empty()
        }
    }

    /// Tool stub that exists only to satisfy the agent's `T: Tool` parameter.
    struct FakeTool;

    impl Tool for FakeTool {
        fn definition(&self) -> ToolDefinition {
            let schema = serde_json::json!({});
            ToolDefinition::new("fake".into(), "fake".into(), schema)
        }

        fn call(&self, _args: serde_json::Value) -> Io<Error, serde_json::Value> {
            Io::pure(serde_json::json!({}))
        }
    }

    /// The preamble and the user input both reach the model.
    #[test]
    fn agent_includes_preamble_in_request() -> Result<(), Error> {
        let agent = AgentBuilder::<FakeModel, FakeTool>::new(FakeModel)
            .preamble("You are helpful.")
            .build();
        let reply = agent.prompt("hello").run()?;
        assert!(reply.contains("You are helpful."));
        assert!(reply.contains("hello"));
        Ok(())
    }

    /// With no preamble, the user message is the only message sent.
    #[test]
    fn agent_without_preamble_sends_only_user_message() -> Result<(), Error> {
        let agent = AgentBuilder::<FakeModel, FakeTool>::new(FakeModel).build();
        let reply = agent.prompt("hello").run()?;
        assert_eq!(reply, "hello");
        Ok(())
    }

    /// Temperature and max tokens set on the builder end up on the request.
    #[test]
    fn agent_applies_temperature_and_max_tokens() {
        let agent = AgentBuilder::<FakeModel, FakeTool>::new(FakeModel)
            .temperature(0.5)
            .max_tokens(100)
            .build();
        let request = agent.build_request("test");
        let temperature = request.temperature().unwrap_or(0.0);
        assert!((temperature - 0.5).abs() < 1e-10);
        assert_eq!(request.max_tokens(), Some(100));
    }
}
183}