1use comp_cat_rs::effect::io::Io;
8use comp_cat_rs::effect::stream::Stream;
9
10use crate::error::Error;
11use crate::model::{
12 CompletionModel, CompletionRequest, Message, StreamChunk,
13};
14use crate::tool::{Tool, Toolbox};
15
16pub struct Agent<M: CompletionModel, T: Tool> {
22 model: M,
23 preamble: Option<String>,
24 tools: Toolbox<T>,
25 temperature: Option<f64>,
26 max_tokens: Option<u32>,
27}
28
29pub struct AgentBuilder<M: CompletionModel, T: Tool> {
31 model: M,
32 preamble: Option<String>,
33 tools: Toolbox<T>,
34 temperature: Option<f64>,
35 max_tokens: Option<u32>,
36}
37
38impl<M: CompletionModel, T: Tool> AgentBuilder<M, T> {
39 #[must_use]
41 pub fn new(model: M) -> Self {
42 Self {
43 model,
44 preamble: None,
45 tools: Toolbox::new(),
46 temperature: None,
47 max_tokens: None,
48 }
49 }
50
51 #[must_use]
53 pub fn preamble(self, preamble: impl Into<String>) -> Self {
54 Self { preamble: Some(preamble.into()), ..self }
55 }
56
57 #[must_use]
59 pub fn tools(self, tools: Toolbox<T>) -> Self {
60 Self { tools, ..self }
61 }
62
63 #[must_use]
65 pub fn temperature(self, t: f64) -> Self {
66 Self { temperature: Some(t), ..self }
67 }
68
69 #[must_use]
71 pub fn max_tokens(self, n: u32) -> Self {
72 Self { max_tokens: Some(n), ..self }
73 }
74
75 #[must_use]
77 pub fn build(self) -> Agent<M, T> {
78 Agent {
79 model: self.model,
80 preamble: self.preamble,
81 tools: self.tools,
82 temperature: self.temperature,
83 max_tokens: self.max_tokens,
84 }
85 }
86}
87
88impl<M: CompletionModel, T: Tool> Agent<M, T> {
89 pub fn prompt(&self, user_input: &str) -> Io<Error, String> {
91 let request = self.build_request(user_input);
92 self.model.complete(request).map(|r| r.content().to_owned())
93 }
94
95 pub fn prompt_stream(&self, user_input: &str) -> Stream<Error, StreamChunk> {
97 let request = self.build_request(user_input);
98 self.model.stream(request)
99 }
100
101 #[must_use]
103 pub fn tools(&self) -> &Toolbox<T> { &self.tools }
104
105 fn build_request(&self, user_input: &str) -> CompletionRequest {
106 let messages = self.preamble.iter()
107 .map(|p| Message::system(p.clone()))
108 .chain(std::iter::once(Message::user(user_input.to_owned())))
109 .collect();
110 let request = CompletionRequest::new(messages);
111 let request = match self.temperature {
112 Some(t) => request.with_temperature(t),
113 None => request,
114 };
115 match self.max_tokens {
116 Some(n) => request.with_max_tokens(n),
117 None => request,
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolDefinition;

    /// Test double whose completion echoes the content of every request
    /// message, joined by `|`, so tests can observe exactly what was sent
    /// and in what order.
    struct FakeModel;

    impl CompletionModel for FakeModel {
        fn complete(&self, request: CompletionRequest) -> Io<Error, crate::model::CompletionResponse> {
            let content = request.messages().iter()
                .map(|m| m.content().to_owned())
                .collect::<Vec<_>>()
                .join("|");
            Io::pure(crate::model::CompletionResponse::new(content, "fake".into()))
        }

        fn stream(&self, _request: CompletionRequest) -> Stream<Error, StreamChunk> {
            Stream::empty()
        }
    }

    /// Inert tool used only to satisfy the agent's `T: Tool` parameter.
    struct FakeTool;

    impl Tool for FakeTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("fake".into(), "fake".into(), serde_json::json!({}))
        }
        fn call(&self, _args: serde_json::Value) -> Io<Error, serde_json::Value> {
            Io::pure(serde_json::json!({}))
        }
    }

    #[test]
    fn agent_includes_preamble_in_request() -> Result<(), Error> {
        let agent: Agent<FakeModel, FakeTool> = AgentBuilder::new(FakeModel)
            .preamble("You are helpful.")
            .build();
        let response = agent.prompt("hello").run()?;
        // An exact match (rather than `contains`) also pins the ordering:
        // the system preamble must come before the user input.
        assert_eq!(response, "You are helpful.|hello");
        Ok(())
    }

    #[test]
    fn agent_without_preamble_sends_only_user_message() -> Result<(), Error> {
        let agent: Agent<FakeModel, FakeTool> = AgentBuilder::new(FakeModel).build();
        let response = agent.prompt("hello").run()?;
        assert_eq!(response, "hello");
        Ok(())
    }

    #[test]
    fn agent_applies_temperature_and_max_tokens() {
        let agent: Agent<FakeModel, FakeTool> = AgentBuilder::new(FakeModel)
            .temperature(0.5)
            .max_tokens(100)
            .build();
        let request = agent.build_request("test");
        assert!((request.temperature().unwrap_or(0.0) - 0.5).abs() < 1e-10);
        assert_eq!(request.max_tokens(), Some(100));
    }

    #[test]
    fn agent_leaves_temperature_and_max_tokens_unset_by_default() {
        let agent: Agent<FakeModel, FakeTool> = AgentBuilder::new(FakeModel).build();
        let request = agent.build_request("test");
        assert!(request.temperature().is_none());
        assert!(request.max_tokens().is_none());
    }

    #[test]
    fn builder_last_setting_wins() {
        let agent: Agent<FakeModel, FakeTool> = AgentBuilder::new(FakeModel)
            .max_tokens(10)
            .max_tokens(20)
            .build();
        let request = agent.build_request("test");
        assert_eq!(request.max_tokens(), Some(20));
    }
}