1use async_trait::async_trait;
2pub use nenjo_tool_api::{ToolCall, ToolCategory, ToolResultMessage, ToolSpec};
3use serde::{Deserialize, Serialize};
4
5use crate::native::{
6 NativeMediaJob, NativeMediaRequest, NativeMediaResponse, NativeModelToolId,
7 ProviderNativeCapabilities,
8};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ChatMessage {
13 pub role: String,
14 pub content: String,
15}
16
17impl ChatMessage {
18 pub fn system(content: impl Into<String>) -> Self {
19 Self {
20 role: "system".into(),
21 content: content.into(),
22 }
23 }
24
25 pub fn user(content: impl Into<String>) -> Self {
26 Self {
27 role: "user".into(),
28 content: content.into(),
29 }
30 }
31
32 pub fn assistant(content: impl Into<String>) -> Self {
33 Self {
34 role: "assistant".into(),
35 content: content.into(),
36 }
37 }
38
39 pub fn tool(content: impl Into<String>) -> Self {
40 Self {
41 role: "tool".into(),
42 content: content.into(),
43 }
44 }
45
46 pub fn developer(content: impl Into<String>) -> Self {
47 Self {
48 role: "developer".into(),
49 content: content.into(),
50 }
51 }
52}
53
/// Token counts reported alongside a single [`ChatResponse`].
///
/// `Default` yields zero for both counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of input (request-side) tokens.
    pub input_tokens: u64,
    /// Number of output (response-side) tokens.
    pub output_tokens: u64,
}
60
61#[derive(Debug, Clone)]
66pub struct ProviderToolTrace {
67 pub id: String,
68 pub name: String,
69 pub provider: String,
70 pub input: serde_json::Value,
71 pub output: Option<serde_json::Value>,
72 pub citations: Vec<serde_json::Value>,
73}
74
75#[derive(Debug, Clone)]
77pub struct ChatResponse {
78 pub text: Option<String>,
80 pub tool_calls: Vec<ToolCall>,
82 pub provider_tool_calls: Vec<ProviderToolTrace>,
84 pub usage: TokenUsage,
86}
87
88#[derive(Debug, Clone)]
94pub enum ProviderStreamEvent {
95 TextDelta(String),
96 ProviderToolStarted(ProviderToolTrace),
97 ProviderToolCompleted(ProviderToolTrace),
98}
99
100impl ChatResponse {
101 pub fn has_tool_calls(&self) -> bool {
103 !self.tool_calls.is_empty()
104 }
105
106 pub fn text_or_empty(&self) -> &str {
108 self.text.as_deref().unwrap_or("")
109 }
110}
111
112#[derive(Debug, Clone, Copy)]
114pub struct ChatRequest<'a> {
115 pub messages: &'a [ChatMessage],
116 pub tools: Option<&'a [ToolSpec]>,
117 pub native_tools: Option<&'a [NativeModelToolId]>,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122#[serde(tag = "type", content = "data")]
123pub enum ConversationMessage {
124 Chat(ChatMessage),
126 AssistantToolCalls {
128 text: Option<String>,
129 tool_calls: Vec<ToolCall>,
130 },
131 ToolResults(Vec<ToolResultMessage>),
133}
134
135#[async_trait]
136pub trait ModelProvider: Send + Sync {
137 async fn chat(
142 &self,
143 request: ChatRequest<'_>,
144 model: &str,
145 temperature: f64,
146 ) -> anyhow::Result<ChatResponse>;
147
148 async fn chat_stream(
154 &self,
155 request: ChatRequest<'_>,
156 model: &str,
157 temperature: f64,
158 events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
159 ) -> anyhow::Result<ChatResponse> {
160 let _ = events;
161 self.chat(request, model, temperature).await
162 }
163
164 fn context_window(&self, _model: &str) -> Option<usize> {
170 None
171 }
172
173 fn supports_native_tools(&self) -> bool {
175 false
176 }
177
178 fn supports_developer_role(&self, _model: &str) -> bool {
182 false
183 }
184
185 fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
190 None
191 }
192
193 async fn submit_media(
195 &self,
196 request: NativeMediaRequest,
197 ) -> anyhow::Result<NativeMediaResponse> {
198 anyhow::bail!(
199 "provider does not support native media operation {:?}",
200 request.operation()
201 )
202 }
203
204 async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
206 let _ = job;
207 anyhow::bail!("provider does not support polling native media jobs")
208 }
209
210 async fn warmup(&self) -> anyhow::Result<()> {
213 Ok(())
214 }
215}
216
217pub async fn one_shot(
220 provider: &dyn ModelProvider,
221 system: Option<&str>,
222 message: &str,
223 model: &str,
224 temperature: f64,
225) -> anyhow::Result<String> {
226 let mut messages = Vec::new();
227 if let Some(sys) = system {
228 if provider.supports_developer_role(model) {
229 messages.push(ChatMessage::developer(sys));
230 } else {
231 messages.push(ChatMessage::system(sys));
232 }
233 }
234 messages.push(ChatMessage::user(message));
235 let request = ChatRequest {
236 messages: &messages,
237 tools: None,
238 native_tools: None,
239 };
240 let response = provider.chat(request, model, temperature).await?;
241 Ok(response.text.unwrap_or_default())
242}
243
/// Unit tests for the message/response types and their serde representation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chat_message_constructors() {
        let sys = ChatMessage::system("Be helpful");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "Be helpful");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");

        let asst = ChatMessage::assistant("Hi there");
        assert_eq!(asst.role, "assistant");

        let tool = ChatMessage::tool("{}");
        assert_eq!(tool.role, "tool");

        let dev = ChatMessage::developer("Follow these instructions");
        assert_eq!(dev.role, "developer");
        assert_eq!(dev.content, "Follow these instructions");
    }

    #[test]
    fn chat_response_helpers() {
        let empty = ChatResponse {
            text: None,
            tool_calls: vec![],
            provider_tool_calls: vec![],
            usage: TokenUsage::default(),
        };
        assert!(!empty.has_tool_calls());
        assert_eq!(empty.text_or_empty(), "");

        let with_tools = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            provider_tool_calls: vec![],
            usage: TokenUsage::default(),
        };
        assert!(with_tools.has_tool_calls());
        assert_eq!(with_tools.text_or_empty(), "Let me check");
    }

    #[test]
    fn tool_call_serialization() {
        let tc = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("call_123"));
        assert!(json.contains("file_read"));
    }

    #[test]
    fn conversation_message_variants() {
        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
        let json = serde_json::to_string(&chat).unwrap();
        assert!(json.contains("\"type\":\"Chat\""));

        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }]);
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains("\"type\":\"ToolResults\""));
    }

    /// A `Chat` entry must survive a serialize/deserialize round trip.
    #[test]
    fn conversation_message_chat_round_trips() {
        let original = ConversationMessage::Chat(ChatMessage::assistant("ok"));
        let json = serde_json::to_string(&original).unwrap();
        match serde_json::from_str::<ConversationMessage>(&json).unwrap() {
            ConversationMessage::Chat(msg) => {
                assert_eq!(msg.role, "assistant");
                assert_eq!(msg.content, "ok");
            }
            other => panic!("expected Chat variant, got {other:?}"),
        }
    }

    /// The struct variant uses the same adjacent `type`/`data` layout and
    /// round-trips its text and tool calls.
    #[test]
    fn conversation_message_assistant_tool_calls_round_trips() {
        let original = ConversationMessage::AssistantToolCalls {
            text: Some("checking".into()),
            tool_calls: vec![ToolCall {
                id: "call_1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
        };
        let value = serde_json::to_value(&original).unwrap();
        assert_eq!(value["type"], "AssistantToolCalls");
        assert_eq!(value["data"]["text"], "checking");

        match serde_json::from_value::<ConversationMessage>(value).unwrap() {
            ConversationMessage::AssistantToolCalls { text, tool_calls } => {
                assert_eq!(text.as_deref(), Some("checking"));
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].name, "shell");
            }
            other => panic!("expected AssistantToolCalls variant, got {other:?}"),
        }
    }
}