1use crate::tool::ToolDef;
7use crate::types::{Message, Role, SgrError, ToolCall};
8use serde_json::Value;
9
10#[async_trait::async_trait]
12pub trait LlmClient: Send + Sync {
13 async fn structured_call(
16 &self,
17 messages: &[Message],
18 schema: &Value,
19 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError>;
20
21 async fn tools_call(
24 &self,
25 messages: &[Message],
26 tools: &[ToolDef],
27 ) -> Result<Vec<ToolCall>, SgrError>;
28
29 async fn tools_call_stateful(
33 &self,
34 messages: &[Message],
35 tools: &[ToolDef],
36 _previous_response_id: Option<&str>,
37 ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
38 let calls = self.tools_call(messages, tools).await?;
40 Ok((calls, None))
41 }
42
43 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError>;
45}
46
47pub fn synthesize_finish_if_empty(calls: &mut Vec<ToolCall>, content: &str) {
51 if calls.is_empty() {
52 let text = content.trim();
53 if !text.is_empty() {
54 calls.push(ToolCall {
55 id: "synth_finish".into(),
56 name: "finish".into(),
57 arguments: serde_json::json!({"summary": text}),
58 });
59 }
60 }
61}
62
63fn inject_schema(messages: &[Message], schema: &Value) -> Vec<Message> {
65 let schema_hint = format!(
66 "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks. Output raw JSON only.",
67 serde_json::to_string_pretty(schema).unwrap_or_default()
68 );
69
70 let mut msgs = Vec::with_capacity(messages.len() + 1);
71 let mut injected = false;
72
73 for msg in messages {
74 if msg.role == Role::System && !injected {
75 msgs.push(Message::system(format!("{}{}", msg.content, schema_hint)));
77 injected = true;
78 } else {
79 msgs.push(msg.clone());
80 }
81 }
82
83 if !injected {
84 msgs.insert(0, Message::system(schema_hint));
86 }
87
88 msgs
89}
90
/// `LlmClient` adapter for `GeminiClient`, delegating to its own methods.
#[cfg(feature = "gemini")]
mod gemini_impl {
    use super::*;
    use crate::gemini::GeminiClient;

    #[async_trait::async_trait]
    impl LlmClient for GeminiClient {
        /// Adds the schema instruction via `inject_schema`, issues a
        /// `flexible::<Value>` call, and unpacks output, tool calls and raw text.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let prompt = inject_schema(messages, schema);
            let response = self.flexible::<Value>(&prompt).await?;
            let parsed = response.output;
            let calls = response.tool_calls;
            let text = response.raw_text;
            Ok((parsed, calls, text))
        }

        /// Forwards to `GeminiClient`'s own `tools_call`.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Assumes `GeminiClient` has an inherent `tools_call` with this
            // signature. Rust method resolution prefers inherent methods over
            // trait methods, so this forwards to it instead of recursing.
            // NOTE(review): if the inherent method is removed or renamed, this
            // silently becomes unbounded recursion.
            self.tools_call(messages, tools).await
        }

        /// Returns the raw text of a `flexible::<Value>` call.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
122
/// `LlmClient` adapter for `OpenAIClient`, delegating to its own methods.
#[cfg(feature = "openai")]
mod openai_impl {
    use super::*;
    use crate::openai::OpenAIClient;

    #[async_trait::async_trait]
    impl LlmClient for OpenAIClient {
        /// Adds the schema instruction via `inject_schema`, issues a
        /// `flexible::<Value>` call, and unpacks output, tool calls and raw text.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let prompt = inject_schema(messages, schema);
            let response = self.flexible::<Value>(&prompt).await?;
            let parsed = response.output;
            let calls = response.tool_calls;
            let text = response.raw_text;
            Ok((parsed, calls, text))
        }

        /// Forwards to `OpenAIClient`'s own `tools_call`.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Assumes `OpenAIClient` has an inherent `tools_call` with this
            // signature. Rust method resolution prefers inherent methods over
            // trait methods, so this forwards to it instead of recursing.
            // NOTE(review): if the inherent method is removed or renamed, this
            // silently becomes unbounded recursion.
            self.tools_call(messages, tools).await
        }

        /// Returns the raw text of a `flexible::<Value>` call.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolDef;

    /// Client implementing only the required methods, so calls to
    /// `tools_call_stateful` exercise the trait's default implementation.
    struct MockStatelessClient;

    #[async_trait::async_trait]
    impl LlmClient for MockStatelessClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![ToolCall {
                id: "tc1".into(),
                name: "test_tool".into(),
                arguments: serde_json::json!({"x": 1}),
            }])
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    #[tokio::test]
    async fn tools_call_stateful_default_delegates() {
        let client = MockStatelessClient;
        let msgs = vec![Message::user("hi")];
        let tools = vec![ToolDef {
            name: "test_tool".into(),
            description: "test".into(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        let (calls, response_id) = client
            .tools_call_stateful(&msgs, &tools, None)
            .await
            .unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
        assert!(response_id.is_none(), "default impl returns no response_id");

        // A supplied previous id is ignored by the stateless default.
        let (calls, response_id) = client
            .tools_call_stateful(&msgs, &tools, Some("resp_abc"))
            .await
            .unwrap();
        assert_eq!(calls.len(), 1);
        assert!(response_id.is_none());
    }

    #[test]
    fn inject_schema_appends_to_existing_system() {
        let msgs = vec![
            Message::system("You are a coding agent."),
            Message::user("hello"),
        ];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 2);
        assert!(result[0].content.contains("You are a coding agent."));
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[0].role, Role::System);
    }

    #[test]
    fn inject_schema_prepends_when_no_system() {
        let msgs = vec![Message::user("hello")];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[1].role, Role::User);
    }

    #[test]
    fn inject_schema_only_first_system_message() {
        let msgs = vec![
            Message::system("System 1"),
            Message::user("msg"),
            Message::system("System 2"),
        ];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 3);
        // The first system message keeps its original text, then the hint.
        assert!(result[0].content.starts_with("System 1"));
        assert!(result[0].content.contains("Respond with valid JSON"));
        // Everything else passes through untouched, in order.
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "msg");
        assert_eq!(result[2].role, Role::System);
        assert_eq!(result[2].content, "System 2");
    }

    #[test]
    fn inject_schema_embeds_schema_text() {
        let msgs = vec![Message::user("hello")];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert!(result[0].content.contains("\"type\": \"object\""));
        assert!(result[0].content.contains("Output raw JSON only."));
    }

    #[test]
    fn synthesize_finish_from_text_when_no_calls() {
        let mut calls = Vec::new();
        synthesize_finish_if_empty(&mut calls, "  all done \n");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "finish");
        // Surrounding whitespace is trimmed from the summary.
        assert_eq!(
            calls[0].arguments,
            serde_json::json!({"summary": "all done"})
        );
    }

    #[test]
    fn synthesize_finish_skips_blank_text() {
        let mut calls = Vec::new();
        synthesize_finish_if_empty(&mut calls, "");
        assert!(calls.is_empty());

        synthesize_finish_if_empty(&mut calls, " \n\t ");
        assert!(calls.is_empty());
    }

    #[test]
    fn synthesize_finish_keeps_existing_calls() {
        let mut calls = vec![ToolCall {
            id: "tc1".into(),
            name: "test_tool".into(),
            arguments: serde_json::json!({}),
        }];
        synthesize_finish_if_empty(&mut calls, "some text");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
    }
}
260}