// sgr_agent/client.rs

//! LlmClient trait — abstract LLM backend for agent use.
//!
//! Implementations wrap `GeminiClient` / `OpenAIClient` existing methods.
//! `structured_call` injects the schema into the system prompt for flexible parsing.

use crate::tool::ToolDef;
use crate::types::{Message, Role, SgrError, ToolCall};
use serde_json::Value;

10/// Abstract LLM client for agent framework.
11#[async_trait::async_trait]
12pub trait LlmClient: Send + Sync {
13    /// Structured call: send messages with schema injected into system prompt.
14    /// Returns (parsed_output, native_tool_calls, raw_text).
15    async fn structured_call(
16        &self,
17        messages: &[Message],
18        schema: &Value,
19    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError>;
20
21    /// Native function calling: send messages + tool defs, get tool calls.
22    /// This is STATELESS — no side effects on shared state.
23    async fn tools_call(
24        &self,
25        messages: &[Message],
26        tools: &[ToolDef],
27    ) -> Result<Vec<ToolCall>, SgrError>;
28
29    /// Stateful function calling with explicit response_id for chaining.
30    /// Returns (tool_calls, new_response_id).
31    /// When previous_response_id is Some, only delta messages are needed.
32    async fn tools_call_stateful(
33        &self,
34        messages: &[Message],
35        tools: &[ToolDef],
36        _previous_response_id: Option<&str>,
37    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
38        // Default: delegate to stateless tools_call, no chaining
39        let calls = self.tools_call(messages, tools).await?;
40        Ok((calls, None))
41    }
42
43    /// Function calling that also returns assistant text content.
44    /// Single-phase agents need both reasoning (text) and actions (tool calls) in one call.
45    /// Default: delegate to tools_call, return empty text.
46    async fn tools_call_with_text(
47        &self,
48        messages: &[Message],
49        tools: &[ToolDef],
50    ) -> Result<(Vec<ToolCall>, String), SgrError> {
51        let calls = self.tools_call(messages, tools).await?;
52        Ok((calls, String::new()))
53    }
54
55    /// Plain text completion (no schema, no tools).
56    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError>;
57}
58
59/// When a model responds with text content instead of tool calls,
60/// synthesize a "finish" tool call so the agent loop gets the answer.
61/// Call this in `tools_call` implementations after extracting tool calls.
62pub fn synthesize_finish_if_empty(calls: &mut Vec<ToolCall>, content: &str) {
63    if calls.is_empty() {
64        let text = content.trim();
65        if !text.is_empty() {
66            calls.push(ToolCall {
67                id: "synth_finish".into(),
68                name: "finish".into(),
69                arguments: serde_json::json!({"summary": text}),
70            });
71        }
72    }
73}
74
75/// Inject schema into messages: append to existing system message or prepend a new one.
76fn inject_schema(messages: &[Message], schema: &Value) -> Vec<Message> {
77    let schema_hint = format!(
78        "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks. Output raw JSON only.",
79        serde_json::to_string_pretty(schema).unwrap_or_default()
80    );
81
82    let mut msgs = Vec::with_capacity(messages.len() + 1);
83    let mut injected = false;
84
85    for msg in messages {
86        if msg.role == Role::System && !injected {
87            // Append schema to existing system message
88            msgs.push(Message::system(format!("{}{}", msg.content, schema_hint)));
89            injected = true;
90        } else {
91            msgs.push(msg.clone());
92        }
93    }
94
95    if !injected {
96        // No system message found — prepend one
97        msgs.insert(0, Message::system(schema_hint));
98    }
99
100    msgs
101}
102
#[cfg(feature = "gemini")]
mod gemini_impl {
    use super::*;
    use crate::gemini::GeminiClient;

    #[async_trait::async_trait]
    impl LlmClient for GeminiClient {
        /// Inject the schema, then let the client parse the reply flexibly.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let with_schema = inject_schema(messages, schema);
            let response = self.flexible::<Value>(&with_schema).await?;
            Ok((response.output, response.tool_calls, response.raw_text))
        }

        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // NOTE(review): this resolves to GeminiClient's INHERENT
            // `tools_call` (inherent methods take precedence over trait
            // methods), not to this trait method — confirm the inherent
            // method exists, otherwise this would recurse.
            self.tools_call(messages, tools).await
        }

        /// Plain completion: no schema injection, return raw text only.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            let response = self.flexible::<Value>(messages).await?;
            Ok(response.raw_text)
        }
    }
}

#[cfg(feature = "openai")]
mod openai_impl {
    use super::*;
    use crate::openai::OpenAIClient;

    #[async_trait::async_trait]
    impl LlmClient for OpenAIClient {
        /// Inject the schema, then let the client parse the reply flexibly.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let with_schema = inject_schema(messages, schema);
            let response = self.flexible::<Value>(&with_schema).await?;
            Ok((response.output, response.tool_calls, response.raw_text))
        }

        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // NOTE(review): this resolves to OpenAIClient's INHERENT
            // `tools_call` (inherent methods take precedence over trait
            // methods), not to this trait method — confirm the inherent
            // method exists, otherwise this would recurse.
            self.tools_call(messages, tools).await
        }

        /// Plain completion: no schema injection, return raw text only.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            let response = self.flexible::<Value>(messages).await?;
            Ok(response.raw_text)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolDef;

    /// Mock client implementing only the required trait methods, so the
    /// provided default for `tools_call_stateful` is the code under test.
    struct MockStatelessClient;

    #[async_trait::async_trait]
    impl LlmClient for MockStatelessClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, Vec::new(), String::new()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Always returns exactly one canned call.
            let call = ToolCall {
                id: "tc1".into(),
                name: "test_tool".into(),
                arguments: serde_json::json!({"x": 1}),
            };
            Ok(vec![call])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    #[tokio::test]
    async fn tools_call_stateful_default_delegates() {
        let client = MockStatelessClient;
        let messages = vec![Message::user("hi")];
        let tool_defs = vec![ToolDef {
            name: "test_tool".into(),
            description: "test".into(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        // Without a previous id: the default impl forwards to tools_call
        // and reports no response_id for chaining.
        let (calls, response_id) = client
            .tools_call_stateful(&messages, &tool_defs, None)
            .await
            .unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
        assert!(response_id.is_none(), "default impl returns no response_id");

        // With a previous id: still stateless — the id is ignored.
        let (calls, response_id) = client
            .tools_call_stateful(&messages, &tool_defs, Some("resp_abc"))
            .await
            .unwrap();
        assert_eq!(calls.len(), 1);
        assert!(response_id.is_none());
    }

    #[test]
    fn inject_schema_appends_to_existing_system() {
        let schema = serde_json::json!({"type": "object"});
        let input = vec![
            Message::system("You are a coding agent."),
            Message::user("hello"),
        ];
        let result = inject_schema(&input, &schema);

        // Message count unchanged; hint appended to the system message.
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.contains("You are a coding agent."));
        assert!(result[0].content.contains("Respond with valid JSON"));
    }

    #[test]
    fn inject_schema_prepends_when_no_system() {
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&[Message::user("hello")], &schema);

        // A new system message is prepended; the user message follows it.
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[1].role, Role::User);
    }

    #[test]
    fn inject_schema_only_first_system_message() {
        let schema = serde_json::json!({"type": "object"});
        let input = vec![
            Message::system("System 1"),
            Message::user("msg"),
            Message::system("System 2"),
        ];
        let result = inject_schema(&input, &schema);

        assert_eq!(result.len(), 3);
        // Only the first system message receives the schema hint.
        assert!(result[0].content.contains("Respond with valid JSON"));
        // Later system messages are left untouched.
        assert_eq!(result[2].content, "System 2");
    }
}