Skip to main content

otherone_ai/
lib.rs

1// 作用:otherone-ai 模块入口 — 统一导出 AI 提供商相关类型和方法
2// 关联:被 otherone-agent 和 otherone-context 调用
3// 预期结果:提供 invoke_model 入口,按 provider 类型分发到对应客户端
4
5pub mod error;
6pub mod traits;
7pub mod types;
8
9pub mod anthropic;
10pub mod fetch;
11pub mod local;
12pub mod openai;
13pub mod openrouter;
14
15use error::AiError;
16use traits::AiProvider;
17use types::{ChatRequest, ChatResponse, ProviderType};
18
19/// 从 OpenAI 风格的 config JSON 构建 ChatRequest
20/// OpenRouter 和 Local provider 都使用与 OpenAI 兼容的格式
21fn build_openai_style_request(options: &openai::types::ConfigOptions) -> ChatRequest {
22    ChatRequest {
23        model: options.model.clone(),
24        messages: options.messages.clone(),
25        max_tokens: options.max_tokens.or(options.context_length),
26        temperature: options.temperature,
27        top_p: options.top_p,
28        tools: options.tools.clone(),
29        tool_choice: options.tool_choice.clone(),
30        stream: options.stream,
31        extra: if options.extra.is_empty() {
32            None
33        } else {
34            Some(options.extra.clone())
35        },
36    }
37}
38
39/// 调用 AI 模型(非流式)
40pub async fn invoke_model(
41    provider: ProviderType,
42    api_key: &str,
43    base_url: &str,
44    config: serde_json::Value,
45) -> Result<ChatResponse, AiError> {
46    if api_key.is_empty() {
47        return Err(AiError::ConfigError("api_key is required".to_string()));
48    }
49    if base_url.is_empty() {
50        return Err(AiError::ConfigError("base_url is required".to_string()));
51    }
52
53    match provider {
54        ProviderType::OpenAI => {
55            let options: openai::types::ConfigOptions = serde_json::from_value(config)
56                .map_err(|e| AiError::ConfigError(format!("Invalid OpenAI config: {}", e)))?;
57            let request = build_openai_style_request(&options);
58            let client =
59                openai::client::OpenAiClient::new(api_key.to_string(), base_url.to_string());
60            client.chat(request).await
61        }
62        ProviderType::OpenRouter => {
63            let options: openai::types::ConfigOptions = serde_json::from_value(config)
64                .map_err(|e| AiError::ConfigError(format!("Invalid OpenRouter config: {}", e)))?;
65            let request = build_openai_style_request(&options);
66            let client = openrouter::client::OpenRouterClient::new(
67                api_key.to_string(),
68                base_url.to_string(),
69            );
70            client.chat(request).await
71        }
72        ProviderType::Local => {
73            let options: openai::types::ConfigOptions = serde_json::from_value(config)
74                .map_err(|e| AiError::ConfigError(format!("Invalid Local config: {}", e)))?;
75            let request = build_openai_style_request(&options);
76            let client = local::client::LocalClient::new(api_key.to_string(), base_url.to_string());
77            client.chat(request).await
78        }
79        ProviderType::Anthropic => {
80            let options: anthropic::types::ConfigOptions = serde_json::from_value(config)
81                .map_err(|e| AiError::ConfigError(format!("Invalid Anthropic config: {}", e)))?;
82            let request = ChatRequest {
83                model: options.model,
84                messages: options.messages,
85                max_tokens: options.max_tokens.or(Some(4096)),
86                temperature: options.temperature,
87                top_p: options.top_p,
88                tools: options.tools,
89                tool_choice: None,
90                stream: options.stream,
91                extra: None,
92            };
93            let client =
94                anthropic::client::AnthropicClient::new(api_key.to_string(), base_url.to_string());
95            client.chat(request).await
96        }
97        ProviderType::Fetch => {
98            let request: ChatRequest = serde_json::from_value(config)
99                .map_err(|e| AiError::ConfigError(format!("Invalid Fetch config: {}", e)))?;
100            let client = fetch::client::FetchClient::new(api_key.to_string(), base_url.to_string());
101            client.chat(request).await
102        }
103    }
104}
105
106/// 调用 AI 模型的流式版本
107pub async fn invoke_model_stream(
108    provider: ProviderType,
109    api_key: &str,
110    base_url: &str,
111    config: serde_json::Value,
112) -> Result<traits::ChatStream, AiError> {
113    if api_key.is_empty() {
114        return Err(AiError::ConfigError("api_key is required".to_string()));
115    }
116    if base_url.is_empty() {
117        return Err(AiError::ConfigError("base_url is required".to_string()));
118    }
119
120    match provider {
121        ProviderType::OpenAI => {
122            let options: openai::types::ConfigOptions = serde_json::from_value(config)
123                .map_err(|e| AiError::ConfigError(format!("Invalid OpenAI config: {}", e)))?;
124            let mut request = build_openai_style_request(&options);
125            request.stream = Some(true);
126            let client =
127                openai::client::OpenAiClient::new(api_key.to_string(), base_url.to_string());
128            client.chat_stream(request).await
129        }
130        ProviderType::OpenRouter => {
131            let options: openai::types::ConfigOptions = serde_json::from_value(config)
132                .map_err(|e| AiError::ConfigError(format!("Invalid OpenRouter config: {}", e)))?;
133            let mut request = build_openai_style_request(&options);
134            request.stream = Some(true);
135            let client = openrouter::client::OpenRouterClient::new(
136                api_key.to_string(),
137                base_url.to_string(),
138            );
139            client.chat_stream(request).await
140        }
141        ProviderType::Local => {
142            let options: openai::types::ConfigOptions = serde_json::from_value(config)
143                .map_err(|e| AiError::ConfigError(format!("Invalid Local config: {}", e)))?;
144            let mut request = build_openai_style_request(&options);
145            request.stream = Some(true);
146            let client = local::client::LocalClient::new(api_key.to_string(), base_url.to_string());
147            client.chat_stream(request).await
148        }
149        ProviderType::Anthropic => {
150            let options: anthropic::types::ConfigOptions = serde_json::from_value(config)
151                .map_err(|e| AiError::ConfigError(format!("Invalid Anthropic config: {}", e)))?;
152            let request = ChatRequest {
153                model: options.model,
154                messages: options.messages,
155                max_tokens: options.max_tokens.or(Some(4096)),
156                temperature: options.temperature,
157                top_p: options.top_p,
158                tools: options.tools,
159                tool_choice: None,
160                stream: Some(true),
161                extra: None,
162            };
163            let client =
164                anthropic::client::AnthropicClient::new(api_key.to_string(), base_url.to_string());
165            client.chat_stream(request).await
166        }
167        ProviderType::Fetch => {
168            let request: ChatRequest = serde_json::from_value(config)
169                .map_err(|e| AiError::ConfigError(format!("Invalid Fetch config: {}", e)))?;
170            let client = fetch::client::FetchClient::new(api_key.to_string(), base_url.to_string());
171            client.chat_stream(request).await
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when `result` failed with exactly the given configuration error message.
    fn is_config_error<T>(result: &Result<T, AiError>, expected: &str) -> bool {
        matches!(result, Err(AiError::ConfigError(message)) if message == expected)
    }

    #[tokio::test]
    async fn test_invoke_model_empty_api_key() {
        let result = invoke_model(
            ProviderType::OpenAI,
            "",
            "https://api.openai.com/v1",
            serde_json::json!({}),
        )
        .await;
        assert!(is_config_error(&result, "api_key is required"));
    }

    #[tokio::test]
    async fn test_invoke_model_empty_base_url() {
        let result =
            invoke_model(ProviderType::OpenAI, "test-key", "", serde_json::json!({})).await;
        assert!(is_config_error(&result, "base_url is required"));
    }

    #[tokio::test]
    async fn test_invoke_model_stream_checks_credentials_before_config() {
        // The config is not even valid for Fetch; the credential check must fail first.
        let result =
            invoke_model_stream(ProviderType::Fetch, "", "", serde_json::json!("not a request"))
                .await;
        assert!(is_config_error(&result, "api_key is required"));
    }

    #[tokio::test]
    async fn test_invoke_model_rejects_invalid_anthropic_config() {
        // Deserialization fails before any client is built, so no network is touched.
        let result = invoke_model(
            ProviderType::Anthropic,
            "test-key",
            "https://example.invalid",
            serde_json::json!("not an object"),
        )
        .await;
        assert!(matches!(
            &result,
            Err(AiError::ConfigError(message)) if message.starts_with("Invalid Anthropic config: ")
        ));
    }

    #[test]
    fn openai_style_request_preserves_extra_chat_params() {
        let options: openai::types::ConfigOptions = serde_json::from_value(serde_json::json!({
            "model": "deepseek-v4-flash",
            "messages": [],
            "reasoning_effort": "low",
            "max_tokens": 16
        }))
        .unwrap();

        let request = build_openai_style_request(&options);
        assert_eq!(request.max_tokens, Some(16));
        assert_eq!(
            request.extra.as_ref().unwrap()["reasoning_effort"],
            serde_json::json!("low")
        );
    }

    #[test]
    fn openai_style_request_omits_empty_extra() {
        let options: openai::types::ConfigOptions = serde_json::from_value(serde_json::json!({
            "model": "deepseek-v4-flash",
            "messages": []
        }))
        .unwrap();

        let request = build_openai_style_request(&options);
        assert!(request.extra.is_none());
    }

    #[tokio::test]
    #[ignore]
    async fn live_openai_compatible_stream_returns_first_delta_chunk() {
        use futures::StreamExt;

        let api_key = std::env::var("OTHERONE_LIVE_API_KEY")
            .expect("OTHERONE_LIVE_API_KEY is required for live stream test");
        let base_url = std::env::var("OTHERONE_LIVE_BASE_URL")
            .unwrap_or_else(|_| "https://api.deepseek.com/v1".to_string());
        let model = std::env::var("OTHERONE_LIVE_MODEL")
            .unwrap_or_else(|_| "deepseek-v4-flash".to_string());

        let mut stream = invoke_model_stream(
            ProviderType::OpenAI,
            &api_key,
            &base_url,
            serde_json::json!({
                "model": model,
                "messages": [
                    { "role": "user", "content": "Reply with only OK." }
                ],
                "max_tokens": 16,
                "stream": true
            }),
        )
        .await
        .expect("opening the live stream should succeed");

        let first_delta = tokio::time::timeout(std::time::Duration::from_secs(30), async {
            while let Some(chunk) = stream.next().await {
                let chunk = chunk.expect("stream chunk should decode");
                let text = chunk
                    .choices
                    .first()
                    .and_then(|choice| choice.delta.as_ref())
                    .and_then(|delta| {
                        // Filter each field separately: an empty `content` must not hide a
                        // non-empty `reasoning_content` carried in the same chunk.
                        delta
                            .content
                            .as_deref()
                            .filter(|content| !content.is_empty())
                            .or_else(|| {
                                delta
                                    .reasoning_content
                                    .as_deref()
                                    .filter(|content| !content.is_empty())
                            })
                    });
                if let Some(delta_text) = text {
                    return Some(delta_text.to_string());
                }
            }
            None
        })
        .await
        .expect("no delta chunk arrived within 30 seconds");

        assert!(first_delta.is_some());
    }
}