Skip to main content

nenjo_models/
lib.rs

1//! # nenjo-providers
2//!
3//! LLM provider trait, types, and implementations for the Nenjo agent platform.
4//!
5//! This crate provides:
6//! - The [`ModelProvider`] trait for LLM integrations
7//! - Message types: [`ChatMessage`], [`ChatRequest`], [`ChatResponse`], [`ToolCall`]
//! - Provider implementations: Anthropic, OpenAI, Gemini, Ollama, OpenRouter, xAI,
//!   and OpenAI-compatible providers
10//! - Reliability wrappers: [`ReliableProvider`] (retry/fallback), [`RouterProvider`] (model routing)
11
12pub mod anthropic;
13pub mod compatible;
14pub mod gemini;
15pub mod native;
16pub mod ollama;
17pub mod openai;
18pub mod openrouter;
19pub mod reliable;
20pub mod router;
21pub mod traits;
22pub mod xai;
23
24// Re-export core types at crate root.
25pub use native::{
26    EditImageRequest, EditVideoRequest, ExtendVideoRequest, GenerateImageRequest,
27    GenerateSpeechRequest, GenerateVideoRequest, ImageToVideoRequest, MediaInputAsset,
28    MediaOutputAsset, MediaOutputFormat, ModelNativeCapabilities, NativeCapabilitiesProvider,
29    NativeExecutionMode, NativeMediaJob, NativeMediaJobStatus, NativeMediaRequest,
30    NativeMediaResponse, NativeModelToolId, NativeOperation, NativeToolSpec,
31    ProviderNativeCapabilities, ProviderNativeModelToolSpec, ReferenceToVideoRequest,
32    TranscribeAudioRequest,
33};
34pub use nenjo_tool_api::{sanitize_tool_name, sanitize_tool_name_lenient};
35pub use traits::{
36    ChatMessage, ChatRequest, ChatResponse, ConversationMessage, ModelProvider,
37    ProviderStreamEvent, ProviderToolTrace, TokenUsage, ToolCall, ToolCategory, ToolResultMessage,
38    ToolSpec, one_shot,
39};
40
41// Re-export provider implementations.
42pub use anthropic::AnthropicProvider;
43pub use compatible::{AuthStyle, OpenAiCompatibleProvider};
44pub use gemini::GeminiProvider;
45pub use ollama::OllamaProvider;
46pub use openai::OpenAiProvider;
47pub use openrouter::OpenRouterProvider;
48pub use reliable::ReliableProvider;
49pub use router::RouterProvider;
50pub use xai::{XAI_DEFAULT_BASE_URL, XAiProvider};
51
52use std::sync::Arc;
53
54use anyhow::Result;
55
/// Maps a model provider name (for example, `"openai"` or `"anthropic"`) to
/// an LLM provider implementation.
///
/// Implementations are responsible for API key resolution and any runtime
/// configuration needed to construct concrete [`ModelProvider`] instances.
///
/// The `Send + Sync` supertraits mean every factory can be shared across
/// threads.
pub trait ModelProviderFactory: Send + Sync {
    /// Create the provider identified by `provider_name`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot build a provider for
    /// `provider_name`; the exact failure conditions are implementation-defined.
    fn create(&self, provider_name: &str) -> Result<Arc<dyn ModelProvider>>;

    /// Create a provider with an optional base URL override.
    ///
    /// Used for self-hosted or OpenAI-compatible providers where the caller
    /// configures a custom endpoint. The default implementation ignores the
    /// URL and delegates to [`create`](Self::create).
    ///
    /// # Errors
    ///
    /// The default implementation returns whatever [`create`](Self::create)
    /// returns.
    fn create_with_base_url(
        &self,
        provider_name: &str,
        base_url: Option<&str>,
    ) -> Result<Arc<dyn ModelProvider>> {
        // Explicitly discard the override so the unused parameter does not warn.
        let _ = base_url;
        self.create(provider_name)
    }
}
78
79impl<T> ModelProviderFactory for Arc<T>
80where
81    T: ModelProviderFactory + ?Sized,
82{
83    fn create(&self, provider_name: &str) -> Result<Arc<dyn ModelProvider>> {
84        self.as_ref().create(provider_name)
85    }
86
87    fn create_with_base_url(
88        &self,
89        provider_name: &str,
90        base_url: Option<&str>,
91    ) -> Result<Arc<dyn ModelProvider>> {
92        self.as_ref().create_with_base_url(provider_name, base_url)
93    }
94}
95
/// Typed variant of [`ModelProviderFactory`] using a generic associated model
/// provider type.
///
/// The lifetime parameter leaves room for factories that return providers
/// borrowing factory-owned shared state, while today's blanket implementation
/// preserves the existing `Arc<dyn ModelProvider>` behavior.
pub trait TypedModelProviderFactory: Send + Sync {
    /// Provider type produced by this factory.
    ///
    /// The `?Sized` bound permits unsized types, so a trait object such as
    /// `dyn ModelProvider` is a valid choice.
    type Provider<'a>: ModelProvider + Send + Sync + ?Sized + 'a
    where
        Self: 'a;

    /// Create the provider identified by `provider_name`, typed as
    /// [`Self::Provider`] at the `'static` lifetime.
    fn create_typed(&self, provider_name: &str) -> Result<Arc<Self::Provider<'static>>>;

    /// Create a typed provider with an optional base URL override.
    ///
    /// The default implementation ignores the URL and delegates to
    /// [`create_typed`](Self::create_typed).
    fn create_typed_with_base_url(
        &self,
        provider_name: &str,
        base_url: Option<&str>,
    ) -> Result<Arc<Self::Provider<'static>>> {
        // Explicitly discard the override so the unused parameter does not warn.
        let _ = base_url;
        self.create_typed(provider_name)
    }
}
118
/// Blanket implementation: every `'static` [`ModelProviderFactory`] is also a
/// [`TypedModelProviderFactory`] whose provider type is the trait object
/// `dyn ModelProvider`.
impl<T> TypedModelProviderFactory for T
where
    T: ModelProviderFactory + ?Sized + 'static,
{
    // The provider type does not depend on `'a`; it is always the `'static`
    // trait object returned by the untyped factory methods.
    type Provider<'a>
        = dyn ModelProvider + 'static
    where
        Self: 'a;

    /// Delegates to [`ModelProviderFactory::create`].
    fn create_typed(&self, provider_name: &str) -> Result<Arc<Self::Provider<'static>>> {
        self.create(provider_name)
    }

    /// Delegates to [`ModelProviderFactory::create_with_base_url`], so the
    /// base URL is passed through rather than dropped by the trait default.
    fn create_typed_with_base_url(
        &self,
        provider_name: &str,
        base_url: Option<&str>,
    ) -> Result<Arc<Self::Provider<'static>>> {
        self.create_with_base_url(provider_name, base_url)
    }
}
140
141// ── Thinking/reasoning helpers ───────────────────────────────────
142
/// Remove every `<think>…</think>` block from model output.
///
/// Reasoning models (DeepSeek-reasoner, MiniMax, etc.) wrap their
/// chain-of-thought in `<think>` tags. Call this on `ChatResponse.text`
/// before the text enters the message history or event stream so that the
/// reasoning is not carried along.
///
/// Behavior:
/// - Text outside the blocks is kept and concatenated in order.
/// - An opening `<think>` with no matching `</think>` discards everything
///   from that tag to the end of the input.
/// - The final result has leading and trailing whitespace trimmed.
pub fn strip_thinking(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut kept = String::with_capacity(text.len());
    let mut rest = text;

    loop {
        // No further opening tag: whatever is left is ordinary output.
        let Some((before, after_open)) = rest.split_once(OPEN_TAG) else {
            kept.push_str(rest);
            break;
        };
        kept.push_str(before);

        match after_open.split_once(CLOSE_TAG) {
            // Skip the reasoning and continue scanning after the closing tag.
            Some((_reasoning, after_close)) => rest = after_close,
            // Unclosed tag: drop everything that follows it.
            None => break,
        }
    }

    kept.trim().to_string()
}
166
167// ── Error helpers ───────────────────────────────────────────────
168
/// Length cap applied to sanitized provider error text before it is truncated
/// and suffixed with `...`.
const MAX_API_ERROR_CHARS: usize = 200;
170
/// Returns `true` for characters treated as part of a secret-like token:
/// ASCII letters and digits plus `-`, `_`, `.`, and `:`.
fn is_secret_char(c: char) -> bool {
    match c {
        'a'..='z' | 'A'..='Z' | '0'..='9' => true,
        '-' | '_' | '.' | ':' => true,
        _ => false,
    }
}
174
175fn token_end(input: &str, from: usize) -> usize {
176    let mut end = from;
177    for (i, c) in input[from..].char_indices() {
178        if is_secret_char(c) {
179            end = from + i + c.len_utf8();
180        } else {
181            break;
182        }
183    }
184    end
185}
186
187/// Scrub known secret-like token prefixes from provider error strings.
188pub fn scrub_secret_patterns(input: &str) -> String {
189    const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
190    let mut scrubbed = input.to_string();
191    for prefix in PREFIXES {
192        let mut search_from = 0;
193        loop {
194            let Some(rel) = scrubbed[search_from..].find(prefix) else {
195                break;
196            };
197            let start = search_from + rel;
198            let content_start = start + prefix.len();
199            let end = token_end(&scrubbed, content_start);
200            if end == content_start {
201                search_from = content_start;
202                continue;
203            }
204            scrubbed.replace_range(start..end, "[REDACTED]");
205            search_from = start + "[REDACTED]".len();
206        }
207    }
208    scrubbed
209}
210
211/// Sanitize API error text by scrubbing secrets and truncating length.
212pub fn sanitize_api_error(input: &str) -> String {
213    let scrubbed = scrub_secret_patterns(input);
214    if scrubbed.chars().count() <= MAX_API_ERROR_CHARS {
215        return scrubbed;
216    }
217    let mut end = MAX_API_ERROR_CHARS;
218    while end > 0 && !scrubbed.is_char_boundary(end) {
219        end -= 1;
220    }
221    format!("{}...", &scrubbed[..end])
222}
223
224/// Build a sanitized provider error from a failed HTTP response.
225pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::Error {
226    let status = response.status();
227    let body = response
228        .text()
229        .await
230        .unwrap_or_else(|_| "<failed to read provider error body>".to_string());
231    let sanitized = sanitize_api_error(&body);
232    anyhow::anyhow!("{provider} API error ({status}): {sanitized}")
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    // A single leading block is removed and the remaining text is trimmed.
    #[test]
    fn strip_thinking_removes_think_block() {
        let cleaned = strip_thinking("<think>\nLet me reason about this...\n</think>\nHello!");
        assert_eq!(cleaned, "Hello!");
    }

    // Text between and after several blocks is kept in order.
    #[test]
    fn strip_thinking_multiple_blocks() {
        let cleaned = strip_thinking("<think>first</think>A<think>second</think>B");
        assert_eq!(cleaned, "AB");
    }

    // Input without any tags passes through unchanged.
    #[test]
    fn strip_thinking_no_tags() {
        let plain = "Just regular text";
        assert_eq!(strip_thinking(plain), plain);
    }

    // An empty block is removed like any other.
    #[test]
    fn strip_thinking_empty_think_block() {
        let cleaned = strip_thinking("<think></think>Hello");
        assert_eq!(cleaned, "Hello");
    }

    // An unclosed tag discards everything after it.
    #[test]
    fn strip_thinking_unclosed_tag() {
        let cleaned = strip_thinking("Before<think>reasoning that never ends...");
        assert_eq!(cleaned, "Before");
    }

    // Output consisting solely of reasoning collapses to an empty string.
    #[test]
    fn strip_thinking_only_thinking() {
        let cleaned = strip_thinking("<think>All reasoning, no output</think>");
        assert_eq!(cleaned, "");
    }
}