1pub mod context;
44pub mod tools;
45
46pub use async_trait::async_trait;
48pub use genai;
49
50use std::marker::PhantomData;
51
52use crate::context::{ChatHistory, ContextProvider, SystemPromptGenerator};
53use genai::chat::{ChatMessage, ChatRequest, ChatResponseFormat, ChatRole, JsonSpec};
54use schemars::JsonSchema;
55use serde::{Deserialize, Serialize};
56
57#[derive(Serialize, Deserialize, JsonSchema)]
59pub struct BasicChatInputSchema {
60 pub chat_message: String,
61}
62
63#[derive(Serialize, Deserialize, JsonSchema)]
65pub struct BasicChatOutputSchema {
66 pub chat_message: String,
67}
68
/// Errors returned when running an agent turn.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
    /// The model's reply could not be used, e.g. it contained no text or the
    /// text did not deserialize into the expected output type. Carries a
    /// description of the problem.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),
    /// The request could not be completed. Carries a description of the
    /// underlying failure.
    #[error("Request failed: {0}")]
    RequestFailed(String),
}
77
/// Internal agent state: a thin wrapper around the [`AgentConfig`] that
/// executes each turn.
pub struct AgentInner {
    // Model, prompt generator, optional history, chat options and client.
    config: AgentConfig,
}
82
83impl AgentInner {
84 fn new(config: AgentConfig) -> Self {
85 Self { config }
86 }
87
88 async fn run_turn<I, O>(&mut self, input: I) -> Result<O, AgentError>
89 where
90 I: Serialize + JsonSchema + Send,
91 O: Serialize + for<'de> Deserialize<'de> + JsonSchema + Send,
92 {
93 self.config.run_turn(input).await
94 }
95}
96
/// Configuration and conversation state for an agent.
///
/// Create one with [`AgentConfig::new`] and customise it with the `with_*`
/// builder methods.
pub struct AgentConfig {
    /// Model name passed to the `genai` client on every request.
    model: String,
    /// Produces the base system prompt; also owns the registered context providers.
    system_prompt_generator: SystemPromptGenerator,
    /// Optional conversation history. When present it is sent with each request
    /// and extended after each successful turn.
    chat_history: Option<ChatHistory>,
    /// Base chat options; a JSON response format is added on top per request.
    chat_options: genai::chat::ChatOptions,
    /// Client used to execute chat requests.
    genai_client: genai::Client,
}
105
106impl AgentConfig {
107 pub fn new(model: impl Into<String>) -> Self {
109 Self {
110 model: model.into(),
111 system_prompt_generator: SystemPromptGenerator::new(),
112 chat_history: None,
113 chat_options: genai::chat::ChatOptions::default(),
114 genai_client: genai::Client::builder().build(),
115 }
116 }
117
118 pub fn with_system_prompt_generator(
120 mut self,
121 system_prompt_generator: SystemPromptGenerator,
122 ) -> Self {
123 self.system_prompt_generator = system_prompt_generator;
124 self
125 }
126
127 pub fn with_chat_history(mut self, chat_history: ChatHistory) -> Self {
129 self.chat_history = Some(chat_history);
130 self
131 }
132
133 pub fn with_chat_options(mut self, chat_options: genai::chat::ChatOptions) -> Self {
135 self.chat_options = chat_options;
136 self
137 }
138
139 fn system_prompt_with_output_schema<O: JsonSchema>(&self) -> String {
141 let system_prompt = self.system_prompt_generator.generate();
142 let schema_context = output_schema_instructions::<O>();
143 format!("{}\n\n{}", system_prompt, schema_context)
144 }
145
146 pub(crate) fn prepare_messages<I: Serialize + JsonSchema, O: JsonSchema>(
148 &self,
149 input: &I,
150 ) -> Vec<ChatMessage> {
151 let system_message = self.system_prompt_with_output_schema::<O>();
152 let mut messages = vec![ChatMessage {
153 role: ChatRole::System,
154 content: system_message.into(),
155 options: None,
156 }];
157 if let Some(chat_history) = &self.chat_history {
158 messages.extend(chat_history.get_history().to_vec());
159 }
160 messages.push(ChatMessage {
161 role: ChatRole::User,
162 content: serde_json::to_string_pretty(input).unwrap().into(),
163 options: None,
164 });
165 messages
166 }
167
168 async fn run_turn<I, O>(&mut self, input: I) -> Result<O, AgentError>
170 where
171 I: Serialize + JsonSchema + Send,
172 O: Serialize + for<'de> Deserialize<'de> + JsonSchema + Send,
173 {
174 let messages = self.prepare_messages::<I, O>(&input);
175
176 let options = self
177 .chat_options
178 .clone()
179 .with_response_format(ChatResponseFormat::JsonSpec(JsonSpec {
180 name: "Output schema".to_string(),
181 description: None,
182 schema: schemars::schema_for!(O).into(),
183 }));
184
185 tracing::debug!(
186 "Sending messages to the model: {}",
187 serde_json::to_string_pretty(&messages).unwrap()
188 );
189
190 let response = self
191 .genai_client
192 .exec_chat(&self.model, ChatRequest::new(messages), Some(&options))
193 .await
194 .map_err(|e| AgentError::RequestFailed(e.to_string()))?;
195
196 let response_text = response.first_text().ok_or_else(|| {
197 AgentError::InvalidResponse("LLM response did not contain text content".to_string())
198 })?;
199 let parsed: O = serde_json::from_str(response_text)
200 .map_err(|e| AgentError::InvalidResponse(e.to_string()))?;
201
202 tracing::debug!(
203 "Received response: {}",
204 serde_json::to_string_pretty(&parsed).unwrap()
205 );
206
207 if let Some(chat_history) = &mut self.chat_history {
208 chat_history.add_message(
209 ChatRole::User,
210 serde_json::to_string_pretty(&input).unwrap().into(),
211 );
212 chat_history.add_message(ChatRole::Assistant, response_text.into());
213 }
214
215 Ok(parsed)
216 }
217}
218
/// Ready-made agent parameterised by its input type `I` and output type `O`.
///
/// Both parameters default to the basic chat schemas.
pub struct BasicNanoAgent<
    I: Serialize + JsonSchema + Send + 'static = BasicChatInputSchema,
    O: for<'de> Deserialize<'de> + JsonSchema + Send + 'static = BasicChatOutputSchema,
> {
    /// State shared by all turns of this agent.
    inner: AgentInner,
    // Records `I` and `O` in the type without storing a value of either.
    _io: PhantomData<fn() -> (I, O)>,
}
231
232impl<I, O> BasicNanoAgent<I, O>
233where
234 I: Serialize + JsonSchema + Send + 'static,
235 O: for<'de> Deserialize<'de> + JsonSchema + Send + 'static,
236{
237 pub fn new(config: AgentConfig) -> Self {
238 Self {
239 inner: AgentInner::new(config),
240 _io: PhantomData,
241 }
242 }
243}
244
245#[async_trait]
252pub trait NanoAgent<
253 I: Serialize + JsonSchema + Send + 'static = BasicChatInputSchema,
254 O: Serialize + for<'de> Deserialize<'de> + JsonSchema + Send + 'static = BasicChatOutputSchema,
255>: Send
256{
257 fn get_inner(&self) -> &AgentInner;
258 fn get_inner_mut(&mut self) -> &mut AgentInner;
259
260 fn register_context_provider(&mut self, provider: impl ContextProvider + Send + 'static) {
261 self.get_inner_mut()
262 .config
263 .system_prompt_generator
264 .get_context_providers_mut()
265 .push(Box::new(provider));
266 }
267
268 async fn run(&mut self, input: I) -> Result<O, AgentError> {
270 self.get_inner_mut().run_turn(input).await
271 }
272}
273
274impl<I, O> NanoAgent<I, O> for BasicNanoAgent<I, O>
275where
276 I: Serialize + JsonSchema + Send + 'static,
277 O: Serialize + for<'de> Deserialize<'de> + JsonSchema + Send + 'static,
278{
279 fn get_inner(&self) -> &AgentInner {
280 &self.inner
281 }
282
283 fn get_inner_mut(&mut self) -> &mut AgentInner {
284 &mut self.inner
285 }
286}
287
/// Builds the output-schema section of the system prompt for type `T`.
///
/// Generates the JSON schema for `T`, pretty-prints it, and places it between
/// an instruction to answer with a single object matching the schema and a
/// reminder to return a conforming value rather than the schema itself.
///
/// # Panics
///
/// Panics if the generated schema cannot be serialized to JSON.
fn output_schema_instructions<T: JsonSchema>() -> String {
    let output_schema = schemars::schema_for!(T);
    // The literal below spans several source lines, so its line breaks and any
    // leading whitespace on the continuation lines are part of the prompt text.
    format!(
        "Understand the request and respond with a single object that matches the following schema.
 Fill string fields with **new assistant-authored** text as appropriate; do not copy the user's wording into those fields unless the task explicitly requires repetition.

 {}

 Return only a response that validates against this schema, not the schema itself.",
        serde_json::to_string_pretty(&output_schema).unwrap()
    )
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ChatHistory;
    use genai::chat::{ChatRole, MessageContent};

    /// Configuration with no chat history attached.
    fn config_without_history() -> AgentConfig {
        AgentConfig::new("test-model")
    }

    /// Wraps `text` in the default input schema.
    fn input_with(text: &str) -> BasicChatInputSchema {
        BasicChatInputSchema {
            chat_message: text.into(),
        }
    }

    /// Prepares messages for `input` using the default input/output schemas.
    fn prepare(cfg: &AgentConfig, input: &BasicChatInputSchema) -> Vec<ChatMessage> {
        cfg.prepare_messages::<BasicChatInputSchema, BasicChatOutputSchema>(input)
    }

    /// Without history the request is exactly a system message then a user message.
    #[test]
    fn prepare_messages_order_without_history_is_system_then_user() {
        let msgs = prepare(&config_without_history(), &input_with("hello"));

        assert_eq!(msgs.len(), 2);
        for (msg, expected_role) in msgs.iter().zip([ChatRole::System, ChatRole::User]) {
            assert_eq!(msg.role, expected_role);
            assert!(msg.options.is_none());
        }
    }

    /// The user message carries the input as pretty-printed JSON.
    #[test]
    fn prepare_messages_user_content_is_pretty_json_of_input() {
        let input = input_with("ping");
        let msgs = prepare(&config_without_history(), &input);
        let expected = serde_json::to_string_pretty(&input).unwrap();

        let user_text = msgs[1].content.first_text().expect("user message text");
        assert_eq!(user_text, expected.as_str());
        assert!(user_text.contains("ping"));
    }

    /// The system message embeds both the instructions and the output schema.
    #[test]
    fn prepare_messages_system_includes_output_schema_instructions() {
        let msgs = prepare(&config_without_history(), &input_with("x"));
        let system = msgs[0].content.first_text().expect("system message text");

        let has_instructions = system.contains("matches the following schema");
        assert!(
            has_instructions,
            "system prompt should embed schema instructions"
        );
        let has_field_name = system.contains("chat_message");
        assert!(
            has_field_name,
            "system prompt should include output JSON schema field names"
        );
    }

    /// Stored history is placed after the system message and before the new user turn.
    #[test]
    fn prepare_messages_inserts_chat_history_between_system_and_user() {
        let mut history = ChatHistory::new();
        let prior_turns = [
            (ChatRole::User, "prior user"),
            (ChatRole::Assistant, "prior assistant"),
        ];
        for (role, text) in prior_turns {
            history.add_message(role, MessageContent::from_text(text));
        }
        let cfg = AgentConfig::new("test-model").with_chat_history(history);

        let msgs = prepare(&cfg, &input_with("latest"));

        assert_eq!(msgs.len(), 4);
        let (system, current) = (&msgs[0], &msgs[3]);
        assert_eq!(system.role, ChatRole::System);
        assert_eq!(
            msgs[1].content.first_text(),
            Some("prior user"),
            "first history turn"
        );
        assert_eq!(
            msgs[2].content.first_text(),
            Some("prior assistant"),
            "second history turn"
        );
        assert_eq!(current.role, ChatRole::User);
        let latest = current.content.first_text().expect("current user turn");
        assert!(
            latest.contains("latest"),
            "final user message should be serialized input"
        );
    }
}