Skip to main content

nenjo_models/
traits.rs

1use async_trait::async_trait;
2pub use nenjo_tool_api::{ToolCall, ToolCategory, ToolResultMessage, ToolSpec};
3use serde::{Deserialize, Serialize};
4
5use crate::native::{
6    NativeMediaJob, NativeMediaRequest, NativeMediaResponse, NativeModelToolId,
7    ProviderNativeCapabilities,
8};
9
10/// A single message in a conversation.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ChatMessage {
13    pub role: String,
14    pub content: String,
15}
16
17impl ChatMessage {
18    pub fn system(content: impl Into<String>) -> Self {
19        Self {
20            role: "system".into(),
21            content: content.into(),
22        }
23    }
24
25    pub fn user(content: impl Into<String>) -> Self {
26        Self {
27            role: "user".into(),
28            content: content.into(),
29        }
30    }
31
32    pub fn assistant(content: impl Into<String>) -> Self {
33        Self {
34            role: "assistant".into(),
35            content: content.into(),
36        }
37    }
38
39    pub fn tool(content: impl Into<String>) -> Self {
40        Self {
41            role: "tool".into(),
42            content: content.into(),
43        }
44    }
45
46    pub fn developer(content: impl Into<String>) -> Self {
47        Self {
48            role: "developer".into(),
49            content: content.into(),
50        }
51    }
52}
53
/// Token usage reported by the LLM provider.
///
/// `TokenUsage::default()` (both counters `0`) is what responses carry when
/// the provider does not report usage.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Input (prompt) tokens, as reported by the provider.
    pub input_tokens: u64,
    /// Output (completion) tokens, as reported by the provider.
    pub output_tokens: u64,
}
60
61/// A provider-executed tool call observed inside a model response.
62///
63/// These traces are informational only. They must not be fed to the local tool
64/// executor because the provider has already executed the tool server-side.
65#[derive(Debug, Clone)]
66pub struct ProviderToolTrace {
67    pub id: String,
68    pub name: String,
69    pub provider: String,
70    pub input: serde_json::Value,
71    pub output: Option<serde_json::Value>,
72    pub citations: Vec<serde_json::Value>,
73}
74
75/// An LLM response that may contain text, tool calls, or both.
76#[derive(Debug, Clone)]
77pub struct ChatResponse {
78    /// Text content of the response (may be empty if only tool calls).
79    pub text: Option<String>,
80    /// Tool calls requested by the LLM for the local runtime to execute.
81    pub tool_calls: Vec<ToolCall>,
82    /// Provider-executed tool calls observed in the model response.
83    pub provider_tool_calls: Vec<ProviderToolTrace>,
84    /// Token usage reported by the provider (zeros when not available).
85    pub usage: TokenUsage,
86}
87
88/// Incremental events emitted while a provider-native model request is running.
89///
90/// These events are provider-agnostic and intentionally lossy: they carry the
91/// information the worker needs to update live activity without baking a single
92/// vendor's raw streaming schema into the turn loop.
93#[derive(Debug, Clone)]
94pub enum ProviderStreamEvent {
95    TextDelta(String),
96    ProviderToolStarted(ProviderToolTrace),
97    ProviderToolCompleted(ProviderToolTrace),
98}
99
100impl ChatResponse {
101    /// True when the LLM wants to invoke at least one tool.
102    pub fn has_tool_calls(&self) -> bool {
103        !self.tool_calls.is_empty()
104    }
105
106    /// Convenience: return text content or empty string.
107    pub fn text_or_empty(&self) -> &str {
108        self.text.as_deref().unwrap_or("")
109    }
110}
111
112/// Request payload for provider chat calls.
113#[derive(Debug, Clone, Copy)]
114pub struct ChatRequest<'a> {
115    pub messages: &'a [ChatMessage],
116    pub tools: Option<&'a [ToolSpec]>,
117    pub native_tools: Option<&'a [NativeModelToolId]>,
118}
119
120/// A message in a multi-turn conversation, including tool interactions.
121#[derive(Debug, Clone, Serialize, Deserialize)]
122#[serde(tag = "type", content = "data")]
123pub enum ConversationMessage {
124    /// Regular chat message (system, user, assistant).
125    Chat(ChatMessage),
126    /// Tool calls from the assistant (stored for history fidelity).
127    AssistantToolCalls {
128        text: Option<String>,
129        tool_calls: Vec<ToolCall>,
130    },
131    /// Results of tool executions, fed back to the LLM.
132    ToolResults(Vec<ToolResultMessage>),
133}
134
135#[async_trait]
136pub trait ModelProvider: Send + Sync {
137    /// Structured chat API — the single required method.
138    ///
139    /// Accepts a full conversation (system + user + assistant + tool messages)
140    /// plus optional tool definitions. Returns text and/or tool calls.
141    async fn chat(
142        &self,
143        request: ChatRequest<'_>,
144        model: &str,
145        temperature: f64,
146    ) -> anyhow::Result<ChatResponse>;
147
148    /// Optional streaming chat API.
149    ///
150    /// Providers that can surface incremental model or provider-native tool
151    /// progress should override this. The default implementation preserves the
152    /// existing non-streaming behavior.
153    async fn chat_stream(
154        &self,
155        request: ChatRequest<'_>,
156        model: &str,
157        temperature: f64,
158        events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
159    ) -> anyhow::Result<ChatResponse> {
160        let _ = events;
161        self.chat(request, model, temperature).await
162    }
163
164    /// Context window size in tokens for the given model.
165    ///
166    /// Providers return the raw advertised context window. The turn loop
167    /// applies its own safety margin. Returns `None` if the model is
168    /// unknown; the turn loop falls back to a conservative default.
169    fn context_window(&self, _model: &str) -> Option<usize> {
170        None
171    }
172
173    /// Whether provider supports native tool calls over API.
174    fn supports_native_tools(&self) -> bool {
175        false
176    }
177
178    /// Whether the given model supports the `developer` message role (OpenAI-spec).
179    /// When true, app-owned instructions are sent as a developer message.
180    /// When false, they are folded into the provider's system-equivalent role.
181    fn supports_developer_role(&self, _model: &str) -> bool {
182        false
183    }
184
185    /// Provider-native capabilities outside the chat/tool turn loop.
186    ///
187    /// Examples include direct image generation, async video rendering,
188    /// text-to-speech, and speech-to-text endpoints.
189    fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
190        None
191    }
192
193    /// Submit a provider-native media operation.
194    async fn submit_media(
195        &self,
196        request: NativeMediaRequest,
197    ) -> anyhow::Result<NativeMediaResponse> {
198        anyhow::bail!(
199            "provider does not support native media operation {:?}",
200            request.operation()
201        )
202    }
203
204    /// Poll an async provider-native media job.
205    async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
206        let _ = job;
207        anyhow::bail!("provider does not support polling native media jobs")
208    }
209
210    /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
211    /// Default implementation is a no-op; providers with HTTP clients should override.
212    async fn warmup(&self) -> anyhow::Result<()> {
213        Ok(())
214    }
215}
216
217/// One-shot helper: builds a ChatRequest from system + user message, calls chat(),
218/// and returns just the text. Used by memory manager and tests.
219pub async fn one_shot(
220    provider: &dyn ModelProvider,
221    system: Option<&str>,
222    message: &str,
223    model: &str,
224    temperature: f64,
225) -> anyhow::Result<String> {
226    let mut messages = Vec::new();
227    if let Some(sys) = system {
228        if provider.supports_developer_role(model) {
229            messages.push(ChatMessage::developer(sys));
230        } else {
231            messages.push(ChatMessage::system(sys));
232        }
233    }
234    messages.push(ChatMessage::user(message));
235    let request = ChatRequest {
236        messages: &messages,
237        tools: None,
238        native_tools: None,
239    };
240    let response = provider.chat(request, model, temperature).await?;
241    Ok(response.text.unwrap_or_default())
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Every constructor must set both the expected role and the given content.
    #[test]
    fn chat_message_constructors() {
        let cases = [
            (ChatMessage::system("Be helpful"), "system", "Be helpful"),
            (ChatMessage::user("Hello"), "user", "Hello"),
            (ChatMessage::assistant("Hi there"), "assistant", "Hi there"),
            (ChatMessage::tool("{}"), "tool", "{}"),
            (
                ChatMessage::developer("Follow these instructions"),
                "developer",
                "Follow these instructions",
            ),
        ];
        for (msg, role, content) in &cases {
            assert_eq!(msg.role, *role);
            assert_eq!(msg.content, *content);
        }
    }

    #[test]
    fn chat_response_helpers() {
        let empty = ChatResponse {
            text: None,
            tool_calls: vec![],
            provider_tool_calls: vec![],
            usage: TokenUsage::default(),
        };
        assert!(!empty.has_tool_calls());
        assert_eq!(empty.text_or_empty(), "");
        // Default usage means "not reported": both counters are zero.
        assert_eq!(empty.usage.input_tokens, 0);
        assert_eq!(empty.usage.output_tokens, 0);

        let with_tools = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
            provider_tool_calls: vec![],
            usage: TokenUsage::default(),
        };
        assert!(with_tools.has_tool_calls());
        assert_eq!(with_tools.text_or_empty(), "Let me check");
    }

    #[test]
    fn tool_call_serialization() {
        let tc = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("call_123"));
        assert!(json.contains("file_read"));

        // Deserializing and re-serializing must reproduce the same JSON.
        let back: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(serde_json::to_string(&back).unwrap(), json);
    }

    #[test]
    fn conversation_message_variants() {
        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
        let json = serde_json::to_string(&chat).unwrap();
        assert!(json.contains("\"type\":\"Chat\""));
        // Adjacently tagged: the payload sits under "data".
        assert!(json.contains("\"data\""));

        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }]);
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains("\"type\":\"ToolResults\""));
    }

    /// Stored history must survive a JSON round trip for every variant.
    #[test]
    fn conversation_message_round_trip() {
        let history = vec![
            ConversationMessage::Chat(ChatMessage::system("rules")),
            ConversationMessage::AssistantToolCalls {
                text: Some("Let me check".into()),
                tool_calls: vec![ToolCall {
                    id: "call_1".into(),
                    name: "shell".into(),
                    arguments: "{}".into(),
                }],
            },
            ConversationMessage::ToolResults(vec![ToolResultMessage {
                tool_call_id: "call_1".into(),
                content: "done".into(),
            }]),
        ];

        let json = serde_json::to_string(&history).unwrap();
        let restored: Vec<ConversationMessage> = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.len(), history.len());
        assert!(matches!(
            &restored[0],
            ConversationMessage::Chat(m) if m.role == "system" && m.content == "rules"
        ));
        assert!(matches!(
            &restored[1],
            ConversationMessage::AssistantToolCalls { text: Some(t), tool_calls }
                if t == "Let me check" && tool_calls.len() == 1
        ));
        assert!(matches!(
            &restored[2],
            ConversationMessage::ToolResults(results) if results.len() == 1
        ));
        // Re-serializing the restored history must reproduce the original JSON.
        assert_eq!(serde_json::to_string(&restored).unwrap(), json);
    }
}