rig_extra/
agent_variant.rs

1use futures::{Stream, StreamExt};
2use serde::de::DeserializeOwned;
3use std::pin::Pin;
4
5use rig::agent::{Agent, MultiTurnStreamItem};
6use rig::completion::{Message, Prompt, PromptError};
7use rig::providers::{
8    anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
9    mistral, moonshot, ollama, openai, openrouter, perplexity, together, xai,
10};
11use rig::streaming::StreamingPrompt;
// Custom provider implemented in this crate (not shipped with rig).
13use crate::extra_providers::bigmodel;
14
/// Unified enum over every supported provider's [`Agent`].
///
/// Each variant wraps an `Agent` specialised to one provider's completion model, so callers
/// can hold any provider behind a single type and dispatch through this type's `prompt` and
/// `stream_prompt` methods.
///
/// NOTE(review): the `Mooshot` variant wraps a `moonshot` model; the name looks like a typo
/// for `Moonshot`, but it is kept because renaming a public variant breaks callers.
#[derive(Clone)]
pub enum AgentVariant {
    /// OpenAI via `openai::completion`.
    OpenAI(Agent<openai::completion::CompletionModel>),
    /// OpenAI via the Responses API (`openai::responses_api`).
    ResponsesOpenAI(Agent<openai::responses_api::ResponsesCompletionModel>),
    Ollama(Agent<ollama::CompletionModel>),
    /// Crate-local provider from `crate::extra_providers::bigmodel`.
    Bigmodel(Agent<bigmodel::CompletionModel>),
    OpenRouter(Agent<openrouter::completion::CompletionModel>),
    Anthropic(Agent<anthropic::completion::CompletionModel>),
    Cohere(Agent<cohere::completion::CompletionModel>),
    Gemini(Agent<gemini::completion::CompletionModel>),
    Huggingface(Agent<huggingface::completion::CompletionModel>),
    Mistral(Agent<mistral::completion::CompletionModel>),
    Together(Agent<together::completion::CompletionModel>),
    XAI(Agent<xai::completion::CompletionModel>),
    Azure(Agent<azure::CompletionModel>),
    DeepSeek(Agent<deepseek::CompletionModel>),
    Galadriel(Agent<galadriel::CompletionModel>),
    Groq(Agent<groq::CompletionModel>),
    Hyperbolic(Agent<hyperbolic::CompletionModel>),
    Mira(Agent<mira::CompletionModel>),
    /// Moonshot provider (variant name kept as `Mooshot` for compatibility).
    Mooshot(Agent<moonshot::CompletionModel>),
    Perplexity(Agent<perplexity::CompletionModel>),
}
39
40/// 统一的返回流类型
41pub type AnyMultiStream<R> =
42    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, PromptError>> + Send>>;
43
44/// 将 Provider 原始流转换成统一的 Event 流
45fn map_stream<SR, R, E>(
46    stream: Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<SR>, E>> + Send>>,
47) -> AnyMultiStream<R>
48where
49    SR: serde::Serialize + Send + 'static,
50    R: DeserializeOwned + Send + 'static,
51    E: std::error::Error + Send + Sync + 'static,
52{
53    Box::pin(stream.map(move |item| {
54        item.map_err(|e| {
55            // Convert error to string and create PromptError
56            let error_msg = format!("stream error: {}", e);
57            PromptError::MaxDepthError {
58                max_depth: 0,
59                chat_history: Box::new(vec![]),
60                prompt: error_msg.into(),
61            }
62        })
63        .and_then(|multi_item| {
64            // Convert MultiTurnStreamItem<SR> to MultiTurnStreamItem<R>
65            // by serializing and deserializing the whole item
66            serde_json::to_value(&multi_item)
67                .map_err(|e| PromptError::MaxDepthError {
68                    max_depth: 0,
69                    chat_history: Box::new(vec![]),
70                    prompt: format!("serialization error: {}", e).into(),
71                })
72                .and_then(|val| {
73                    serde_json::from_value(val).map_err(|e| PromptError::MaxDepthError {
74                        max_depth: 0,
75                        chat_history: Box::new(vec![]),
76                        prompt: format!("deserialization error: {}", e).into(),
77                    })
78                })
79        })
80    }))
81}
82
83// ======================
84// 同步 Prompt 调用
85// ======================
86impl AgentVariant {
87    pub async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
88        match self {
89            AgentVariant::OpenAI(a) => a.prompt(prompt).await,
90            AgentVariant::ResponsesOpenAI(a) => a.prompt(prompt).await,
91            AgentVariant::Ollama(a) => a.prompt(prompt).await,
92            AgentVariant::Bigmodel(a) => a.prompt(prompt).await,
93            AgentVariant::OpenRouter(a) => a.prompt(prompt).await,
94            AgentVariant::Anthropic(a) => a.prompt(prompt).await,
95            AgentVariant::Cohere(a) => a.prompt(prompt).await,
96            AgentVariant::Gemini(a) => a.prompt(prompt).await,
97            AgentVariant::Huggingface(a) => a.prompt(prompt).await,
98            AgentVariant::Mistral(a) => a.prompt(prompt).await,
99            AgentVariant::Together(a) => a.prompt(prompt).await,
100            AgentVariant::XAI(a) => a.prompt(prompt).await,
101            AgentVariant::Azure(a) => a.prompt(prompt).await,
102            AgentVariant::DeepSeek(a) => a.prompt(prompt).await,
103            AgentVariant::Galadriel(a) => a.prompt(prompt).await,
104            AgentVariant::Groq(a) => a.prompt(prompt).await,
105            AgentVariant::Hyperbolic(a) => a.prompt(prompt).await,
106            AgentVariant::Mira(a) => a.prompt(prompt).await,
107            AgentVariant::Mooshot(a) => a.prompt(prompt).await,
108            AgentVariant::Perplexity(a) => a.prompt(prompt).await,
109        }
110    }
111}
112
113// ======================
114// Streaming Prompt 调用(重点)
115// ======================
116impl AgentVariant {
117    pub async fn stream_prompt<R>(
118        &self,
119        prompt: impl Into<Message> + Send,
120    ) -> Result<AnyMultiStream<R>, PromptError>
121    where
122        R: DeserializeOwned + Send + 'static,
123    {
124        async fn handle<M, R>(
125            agent: &Agent<M>,
126            prompt: impl Into<Message> + Send,
127        ) -> Result<AnyMultiStream<R>, PromptError>
128        where
129            M: rig::completion::CompletionModel + Send + Sync + 'static,
130            <M as rig::completion::CompletionModel>::StreamingResponse:
131                serde::Serialize + Send + 'static,
132            R: DeserializeOwned + Send + 'static,
133        {
134            // Provider 层的错误类型通常是 StreamingError
135            let raw_stream = agent.stream_prompt(prompt).await;
136
137            Ok(map_stream(raw_stream))
138        }
139
140        match self {
141            AgentVariant::OpenAI(a) => handle::<_, R>(a, prompt).await,
142            AgentVariant::ResponsesOpenAI(a) => handle::<_, R>(a, prompt).await,
143            AgentVariant::Ollama(a) => handle::<_, R>(a, prompt).await,
144            AgentVariant::Bigmodel(a) => handle::<_, R>(a, prompt).await,
145            AgentVariant::OpenRouter(a) => handle::<_, R>(a, prompt).await,
146            AgentVariant::Anthropic(a) => handle::<_, R>(a, prompt).await,
147            AgentVariant::Cohere(a) => handle::<_, R>(a, prompt).await,
148            AgentVariant::Gemini(a) => handle::<_, R>(a, prompt).await,
149            AgentVariant::Huggingface(a) => handle::<_, R>(a, prompt).await,
150            AgentVariant::Mistral(a) => handle::<_, R>(a, prompt).await,
151            AgentVariant::Together(a) => handle::<_, R>(a, prompt).await,
152            AgentVariant::XAI(a) => handle::<_, R>(a, prompt).await,
153            AgentVariant::Azure(a) => handle::<_, R>(a, prompt).await,
154            AgentVariant::DeepSeek(a) => handle::<_, R>(a, prompt).await,
155            AgentVariant::Galadriel(a) => handle::<_, R>(a, prompt).await,
156            AgentVariant::Groq(a) => handle::<_, R>(a, prompt).await,
157            AgentVariant::Hyperbolic(a) => handle::<_, R>(a, prompt).await,
158            AgentVariant::Mira(a) => handle::<_, R>(a, prompt).await,
159            AgentVariant::Mooshot(a) => handle::<_, R>(a, prompt).await,
160            AgentVariant::Perplexity(a) => handle::<_, R>(a, prompt).await,
161        }
162    }
163}