rig/providers/
mira.rs

//! Mira API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::providers::mira;
//!
//! // `Client::new` returns a `Result`, so unwrap or propagate it.
//! let client = mira::Client::new("YOUR_API_KEY").expect("failed to create Mira client");
//! ```
10use crate::json_utils::merge;
11use crate::providers::openai::send_compatible_streaming_request;
12use crate::streaming::{StreamingCompletionModel, StreamingResult};
13use crate::{
14    agent::AgentBuilder,
15    completion::{self, CompletionError, CompletionRequest},
16    extractor::ExtractorBuilder,
17    message::{self, AssistantContent, Message, UserContent},
18    OneOrMany,
19};
20use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::{json, Value};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30    #[error("Invalid API key")]
31    InvalidApiKey,
32    #[error("API error: {0}")]
33    ApiError(u16),
34    #[error("Request error: {0}")]
35    RequestError(#[from] reqwest::Error),
36    #[error("UTF-8 error: {0}")]
37    Utf8Error(#[from] FromUtf8Error),
38    #[error("JSON error: {0}")]
39    JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44    message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49    pub role: String,
50    pub content: String,
51}
52
53const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56    type Error = CompletionError;
57
58    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59        match raw.role.as_str() {
60            "user" => Ok(message::Message::User {
61                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62            }),
63            "assistant" => Ok(message::Message::Assistant {
64                content: OneOrMany::one(AssistantContent::Text(message::Text {
65                    text: raw.content,
66                })),
67            }),
68            _ => Err(CompletionError::ResponseError(format!(
69                "Unsupported message role: {}",
70                raw.role
71            ))),
72        }
73    }
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(untagged)]
78pub enum CompletionResponse {
79    Structured {
80        id: String,
81        object: String,
82        created: u64,
83        model: String,
84        choices: Vec<ChatChoice>,
85        #[serde(skip_serializing_if = "Option::is_none")]
86        usage: Option<Usage>,
87    },
88    Simple(String),
89}
90
91#[derive(Debug, Deserialize)]
92pub struct ChatChoice {
93    pub message: RawMessage,
94    #[serde(default)]
95    pub finish_reason: Option<String>,
96    #[serde(default)]
97    pub index: Option<usize>,
98}
99
100#[derive(Debug, Deserialize)]
101struct ModelsResponse {
102    data: Vec<ModelInfo>,
103}
104
105#[derive(Debug, Deserialize)]
106struct ModelInfo {
107    id: String,
108}
109
110#[derive(Clone)]
111/// Client for interacting with the Mira API
112pub struct Client {
113    base_url: String,
114    client: reqwest::Client,
115    headers: HeaderMap,
116}
117
118impl Client {
119    /// Create a new Mira client with the given API key
120    pub fn new(api_key: &str) -> Result<Self, MiraError> {
121        let mut headers = HeaderMap::new();
122        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
123        headers.insert(
124            AUTHORIZATION,
125            HeaderValue::from_str(&format!("Bearer {}", api_key))
126                .map_err(|_| MiraError::InvalidApiKey)?,
127        );
128        headers.insert(
129            reqwest::header::ACCEPT,
130            HeaderValue::from_static("application/json"),
131        );
132        headers.insert(
133            reqwest::header::USER_AGENT,
134            HeaderValue::from_static("rig-client/1.0"),
135        );
136
137        Ok(Self {
138            base_url: MIRA_API_BASE_URL.to_string(),
139            client: reqwest::Client::builder()
140                .build()
141                .expect("Failed to build HTTP client"),
142            headers,
143        })
144    }
145
146    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
147    /// Panics if the environment variable is not set.
148    pub fn from_env() -> Result<Self, MiraError> {
149        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
150        Self::new(&api_key)
151    }
152
153    /// Create a new Mira client with a custom base URL and API key
154    pub fn new_with_base_url(
155        api_key: &str,
156        base_url: impl Into<String>,
157    ) -> Result<Self, MiraError> {
158        let mut client = Self::new(api_key)?;
159        client.base_url = base_url.into();
160        Ok(client)
161    }
162
163    /// List available models
164    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
165        let url = format!("{}/v1/models", self.base_url);
166
167        let response = self
168            .client
169            .get(&url)
170            .headers(self.headers.clone())
171            .send()
172            .await?;
173
174        let status = response.status();
175
176        if !status.is_success() {
177            // Log the error text but don't store it in an unused variable
178            let _error_text = response.text().await.unwrap_or_default();
179            tracing::error!("Error response: {}", _error_text);
180            return Err(MiraError::ApiError(status.as_u16()));
181        }
182
183        let response_text = response.text().await?;
184
185        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
186            tracing::error!("Failed to parse response: {}", e);
187            MiraError::JsonError(e)
188        })?;
189
190        Ok(models.data.into_iter().map(|model| model.id).collect())
191    }
192
193    /// Create a completion model with the given name.
194    pub fn completion_model(&self, model: &str) -> CompletionModel {
195        CompletionModel::new(self.to_owned(), model)
196    }
197
198    /// Create an agent builder with the given completion model.
199    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
200        AgentBuilder::new(self.completion_model(model))
201    }
202
203    /// Create an extractor builder with the given completion model.
204    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
205        &self,
206        model: &str,
207    ) -> ExtractorBuilder<T, CompletionModel> {
208        ExtractorBuilder::new(self.completion_model(model))
209    }
210}
211
212#[derive(Clone)]
213pub struct CompletionModel {
214    client: Client,
215    /// Name of the model
216    pub model: String,
217}
218
219impl CompletionModel {
220    pub fn new(client: Client, model: &str) -> Self {
221        Self {
222            client,
223            model: model.to_string(),
224        }
225    }
226
227    fn create_completion_request(
228        &self,
229        completion_request: CompletionRequest,
230    ) -> Result<Value, CompletionError> {
231        let mut messages = Vec::new();
232
233        // Add preamble as user message if available
234        if let Some(preamble) = &completion_request.preamble {
235            messages.push(serde_json::json!({
236                "role": "user",
237                "content": preamble.to_string()
238            }));
239        }
240
241        // Add prompt
242        messages.push(match &completion_request.prompt {
243            Message::User { content } => {
244                let text = content
245                    .iter()
246                    .map(|c| match c {
247                        UserContent::Text(text) => &text.text,
248                        _ => "",
249                    })
250                    .collect::<Vec<_>>()
251                    .join("\n");
252                serde_json::json!({
253                    "role": "user",
254                    "content": text
255                })
256            }
257            _ => unreachable!(),
258        });
259
260        // Add chat history
261        for msg in completion_request.chat_history {
262            let (role, content) = match msg {
263                Message::User { content } => {
264                    let text = content
265                        .iter()
266                        .map(|c| match c {
267                            UserContent::Text(text) => &text.text,
268                            _ => "",
269                        })
270                        .collect::<Vec<_>>()
271                        .join("\n");
272                    ("user", text)
273                }
274                Message::Assistant { content } => {
275                    let text = content
276                        .iter()
277                        .map(|c| match c {
278                            AssistantContent::Text(text) => &text.text,
279                            _ => "",
280                        })
281                        .collect::<Vec<_>>()
282                        .join("\n");
283                    ("assistant", text)
284                }
285            };
286            messages.push(serde_json::json!({
287                "role": role,
288                "content": content
289            }));
290        }
291
292        let request = serde_json::json!({
293            "model": self.model,
294            "messages": messages,
295            "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
296            "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
297            "stream": false
298        });
299
300        Ok(request)
301    }
302}
303
304impl completion::CompletionModel for CompletionModel {
305    type Response = CompletionResponse;
306
307    #[cfg_attr(feature = "worker", worker::send)]
308    async fn completion(
309        &self,
310        completion_request: CompletionRequest,
311    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
312        if !completion_request.tools.is_empty() {
313            tracing::warn!(target: "rig",
314                "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
315                completion_request.tools.len()
316            );
317        }
318
319        let mira_request = self.create_completion_request(completion_request)?;
320
321        let response = self
322            .client
323            .client
324            .post(format!("{}/v1/chat/completions", self.client.base_url))
325            .headers(self.client.headers.clone())
326            .json(&mira_request)
327            .send()
328            .await
329            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
330
331        if !response.status().is_success() {
332            let status = response.status().as_u16();
333            let error_text = response.text().await.unwrap_or_default();
334            return Err(CompletionError::ProviderError(format!(
335                "API error: {} - {}",
336                status, error_text
337            )));
338        }
339
340        let response: CompletionResponse = response
341            .json()
342            .await
343            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
344
345        response.try_into()
346    }
347}
348
349impl StreamingCompletionModel for CompletionModel {
350    async fn stream(
351        &self,
352        completion_request: CompletionRequest,
353    ) -> Result<StreamingResult, CompletionError> {
354        let mut request = self.create_completion_request(completion_request)?;
355
356        request = merge(request, json!({"stream": true}));
357
358        let builder = self
359            .client
360            .client
361            .post(format!("{}/v1/chat/completions", self.client.base_url))
362            .headers(self.client.headers.clone())
363            .json(&request);
364
365        send_compatible_streaming_request(builder).await
366    }
367}
368
369impl From<ApiErrorResponse> for CompletionError {
370    fn from(err: ApiErrorResponse) -> Self {
371        CompletionError::ProviderError(err.message)
372    }
373}
374
375impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
376    type Error = CompletionError;
377
378    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
379        let content = match &response {
380            CompletionResponse::Structured { choices, .. } => {
381                let choice = choices.first().ok_or_else(|| {
382                    CompletionError::ResponseError("Response contained no choices".to_owned())
383                })?;
384
385                // Convert RawMessage to message::Message
386                let message = message::Message::try_from(choice.message.clone())?;
387
388                match message {
389                    Message::Assistant { content } => {
390                        if content.is_empty() {
391                            return Err(CompletionError::ResponseError(
392                                "Response contained empty content".to_owned(),
393                            ));
394                        }
395
396                        // Log warning for unsupported content types
397                        for c in content.iter() {
398                            if !matches!(c, AssistantContent::Text(_)) {
399                                tracing::warn!(target: "rig",
400                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
401                                );
402                            }
403                        }
404
405                        content.iter().map(|c| {
406                            match c {
407                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
408                                other => Err(CompletionError::ResponseError(
409                                    format!("Unsupported content type: {:?}. The Mira provider currently only supports text content", other)
410                                ))
411                            }
412                        }).collect::<Result<Vec<_>, _>>()?
413                    }
414                    Message::User { .. } => {
415                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
416                        return Err(CompletionError::ResponseError(
417                            "Received user message in response where assistant message was expected".to_owned()
418                        ));
419                    }
420                }
421            }
422            CompletionResponse::Simple(text) => {
423                vec![completion::AssistantContent::text(text)]
424            }
425        };
426
427        let choice = OneOrMany::many(content).map_err(|_| {
428            CompletionError::ResponseError(
429                "Response contained no message or tool call (empty)".to_owned(),
430            )
431        })?;
432
433        Ok(completion::CompletionResponse {
434            choice,
435            raw_response: response,
436        })
437    }
438}
439
440#[derive(Clone, Debug, Deserialize)]
441pub struct Usage {
442    pub prompt_tokens: usize,
443    pub total_tokens: usize,
444}
445
446impl std::fmt::Display for Usage {
447    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448        write!(
449            f,
450            "Prompt tokens: {} Total tokens: {}",
451            self.prompt_tokens, self.total_tokens
452        )
453    }
454}
455
456impl From<Message> for serde_json::Value {
457    fn from(msg: Message) -> Self {
458        match msg {
459            Message::User { content } => {
460                let text = content
461                    .iter()
462                    .map(|c| match c {
463                        UserContent::Text(text) => &text.text,
464                        _ => "",
465                    })
466                    .collect::<Vec<_>>()
467                    .join("\n");
468                serde_json::json!({
469                    "role": "user",
470                    "content": text
471                })
472            }
473            Message::Assistant { content } => {
474                let text = content
475                    .iter()
476                    .map(|c| match c {
477                        AssistantContent::Text(text) => &text.text,
478                        _ => "",
479                    })
480                    .collect::<Vec<_>>()
481                    .join("\n");
482                serde_json::json!({
483                    "role": "assistant",
484                    "content": text
485                })
486            }
487        }
488    }
489}
490
491impl TryFrom<serde_json::Value> for Message {
492    type Error = CompletionError;
493
494    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
495        let role = value["role"].as_str().ok_or_else(|| {
496            CompletionError::ResponseError("Message missing role field".to_owned())
497        })?;
498
499        // Handle both string and array content formats
500        let content = match value.get("content") {
501            Some(content) => match content {
502                serde_json::Value::String(s) => s.clone(),
503                serde_json::Value::Array(arr) => arr
504                    .iter()
505                    .filter_map(|c| {
506                        c.get("text")
507                            .and_then(|t| t.as_str())
508                            .map(|text| text.to_string())
509                    })
510                    .collect::<Vec<_>>()
511                    .join("\n"),
512                _ => {
513                    return Err(CompletionError::ResponseError(
514                        "Message content must be string or array".to_owned(),
515                    ))
516                }
517            },
518            None => {
519                return Err(CompletionError::ResponseError(
520                    "Message missing content field".to_owned(),
521                ))
522            }
523        };
524
525        match role {
526            "user" => Ok(Message::User {
527                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
528            }),
529            "assistant" => Ok(Message::Assistant {
530                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
531            }),
532            _ => Err(CompletionError::ResponseError(format!(
533                "Unsupported message role: {}",
534                role
535            ))),
536        }
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Build a structured response around `choices` with fixed metadata and no usage.
    fn structured_response(choices: Vec<ChatChoice>) -> CompletionResponse {
        CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices,
            usage: None,
        }
    }

    /// Build a single choice carrying `role` and `content`.
    fn choice(role: &str, content: &str) -> ChatChoice {
        ChatChoice {
            message: RawMessage {
                role: role.to_string(),
                content: content.to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        }
    }

    #[test]
    fn test_deserialize_message() {
        // Test string content format
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        // Test array content format
        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        // Test string content format
        match assistant_message {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        // Test array content format
        match assistant_message_array {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    #[test]
    fn test_deserialize_message_errors() {
        assert!(Message::try_from(json!({ "content": "x" })).is_err());
        assert!(Message::try_from(json!({ "role": "user" })).is_err());
        assert!(Message::try_from(json!({ "role": "user", "content": 1 })).is_err());
        assert!(Message::try_from(json!({ "role": "system", "content": "x" })).is_err());
    }

    #[test]
    fn test_message_conversion() {
        // Test converting from our Message type to Mira's format and back
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // Convert to Mira format
        let mira_value: serde_json::Value = original_message.clone().try_into().unwrap();

        // Convert back to our Message type
        let converted_message: Message = mira_value.try_into().unwrap();

        // Convert back to original format
        let final_message: message::Message = converted_message.try_into().unwrap();

        assert_eq!(original_message, final_message);
    }

    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice("assistant", "Test response")],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    #[test]
    fn test_simple_response_conversion() {
        let completion_response: completion::CompletionResponse<CompletionResponse> =
            CompletionResponse::Simple("Plain answer".to_string())
                .try_into()
                .unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Plain answer")
        );
    }

    #[test]
    fn test_completion_response_rejects_invalid_choices() {
        let no_choices: Result<completion::CompletionResponse<CompletionResponse>, _> =
            structured_response(vec![]).try_into();
        assert!(matches!(no_choices, Err(CompletionError::ResponseError(_))));

        for role in ["user", "system"] {
            let result: Result<completion::CompletionResponse<CompletionResponse>, _> =
                structured_response(vec![choice(role, "x")]).try_into();
            assert!(matches!(result, Err(CompletionError::ResponseError(_))));
        }
    }

    #[test]
    fn test_deserialize_completion_response_variants() {
        // Object bodies hit the structured variant; missing optional fields are `None`.
        let structured: CompletionResponse = serde_json::from_value(json!({
            "id": "resp_1",
            "object": "chat.completion",
            "created": 1,
            "model": "deepseek-r1",
            "choices": [{ "message": { "role": "assistant", "content": "hi" } }]
        }))
        .unwrap();

        match structured {
            CompletionResponse::Structured { choices, usage, .. } => {
                assert_eq!(choices.len(), 1);
                assert_eq!(choices[0].message.content, "hi");
                assert!(choices[0].finish_reason.is_none());
                assert!(choices[0].index.is_none());
                assert!(usage.is_none());
            }
            CompletionResponse::Simple(_) => panic!("Expected structured response"),
        }

        // A bare JSON string falls through to the simple variant.
        let simple: CompletionResponse = serde_json::from_value(json!("plain text")).unwrap();
        match simple {
            CompletionResponse::Simple(text) => assert_eq!(text, "plain text"),
            CompletionResponse::Structured { .. } => panic!("Expected simple response"),
        }
    }
}