rig/providers/
mira.rs

//! Mira API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::mira;
//!
//! let client = mira::Client::new("YOUR_API_KEY").expect("failed to create Mira client");
//! ```
10use crate::client::{CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16    completion::{self, CompletionError, CompletionRequest},
17    impl_conversion_traits,
18    message::{self, AssistantContent, Message, UserContent},
19    OneOrMany,
20};
21use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
22use serde::Deserialize;
23use serde_json::{json, Value};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30    #[error("Invalid API key")]
31    InvalidApiKey,
32    #[error("API error: {0}")]
33    ApiError(u16),
34    #[error("Request error: {0}")]
35    RequestError(#[from] reqwest::Error),
36    #[error("UTF-8 error: {0}")]
37    Utf8Error(#[from] FromUtf8Error),
38    #[error("JSON error: {0}")]
39    JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44    message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49    pub role: String,
50    pub content: String,
51}
52
53const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56    type Error = CompletionError;
57
58    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59        match raw.role.as_str() {
60            "user" => Ok(message::Message::User {
61                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62            }),
63            "assistant" => Ok(message::Message::Assistant {
64                content: OneOrMany::one(AssistantContent::Text(message::Text {
65                    text: raw.content,
66                })),
67            }),
68            _ => Err(CompletionError::ResponseError(format!(
69                "Unsupported message role: {}",
70                raw.role
71            ))),
72        }
73    }
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(untagged)]
78pub enum CompletionResponse {
79    Structured {
80        id: String,
81        object: String,
82        created: u64,
83        model: String,
84        choices: Vec<ChatChoice>,
85        #[serde(skip_serializing_if = "Option::is_none")]
86        usage: Option<Usage>,
87    },
88    Simple(String),
89}
90
91#[derive(Debug, Deserialize)]
92pub struct ChatChoice {
93    pub message: RawMessage,
94    #[serde(default)]
95    pub finish_reason: Option<String>,
96    #[serde(default)]
97    pub index: Option<usize>,
98}
99
100#[derive(Debug, Deserialize)]
101struct ModelsResponse {
102    data: Vec<ModelInfo>,
103}
104
105#[derive(Debug, Deserialize)]
106struct ModelInfo {
107    id: String,
108}
109
110#[derive(Clone, Debug)]
111/// Client for interacting with the Mira API
112pub struct Client {
113    base_url: String,
114    client: reqwest::Client,
115    headers: HeaderMap,
116}
117
118impl Client {
119    /// Create a new Mira client with the given API key
120    pub fn new(api_key: &str) -> Result<Self, MiraError> {
121        let mut headers = HeaderMap::new();
122        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
123        headers.insert(
124            AUTHORIZATION,
125            HeaderValue::from_str(&format!("Bearer {api_key}"))
126                .map_err(|_| MiraError::InvalidApiKey)?,
127        );
128        headers.insert(
129            reqwest::header::ACCEPT,
130            HeaderValue::from_static("application/json"),
131        );
132        headers.insert(
133            reqwest::header::USER_AGENT,
134            HeaderValue::from_static("rig-client/1.0"),
135        );
136
137        Ok(Self {
138            base_url: MIRA_API_BASE_URL.to_string(),
139            client: reqwest::Client::builder()
140                .build()
141                .expect("Failed to build HTTP client"),
142            headers,
143        })
144    }
145
146    /// Create a new Mira client with a custom base URL and API key
147    pub fn new_with_base_url(
148        api_key: &str,
149        base_url: impl Into<String>,
150    ) -> Result<Self, MiraError> {
151        let mut client = Self::new(api_key)?;
152        client.base_url = base_url.into();
153        Ok(client)
154    }
155
156    /// List available models
157    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
158        let url = format!("{}/v1/models", self.base_url);
159
160        let response = self
161            .client
162            .get(&url)
163            .headers(self.headers.clone())
164            .send()
165            .await?;
166
167        let status = response.status();
168
169        if !status.is_success() {
170            // Log the error text but don't store it in an unused variable
171            let _error_text = response.text().await.unwrap_or_default();
172            tracing::error!("Error response: {}", _error_text);
173            return Err(MiraError::ApiError(status.as_u16()));
174        }
175
176        let response_text = response.text().await?;
177
178        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
179            tracing::error!("Failed to parse response: {}", e);
180            MiraError::JsonError(e)
181        })?;
182
183        Ok(models.data.into_iter().map(|model| model.id).collect())
184    }
185}
186
187impl ProviderClient for Client {
188    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
189    /// Panics if the environment variable is not set.
190    fn from_env() -> Self {
191        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
192        Self::new(&api_key).expect("Could not create Mira Client")
193    }
194}
195
196impl CompletionClient for Client {
197    type CompletionModel = CompletionModel;
198    /// Create a completion model with the given name.
199    fn completion_model(&self, model: &str) -> CompletionModel {
200        CompletionModel::new(self.to_owned(), model)
201    }
202}
203
// In this file `Client` implements only `CompletionClient`. This macro generates the
// remaining capability-conversion impls for it. These are presumably the "not supported"
// defaults; confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
210
211#[derive(Clone)]
212pub struct CompletionModel {
213    client: Client,
214    /// Name of the model
215    pub model: String,
216}
217
218impl CompletionModel {
219    pub fn new(client: Client, model: &str) -> Self {
220        Self {
221            client,
222            model: model.to_string(),
223        }
224    }
225
226    fn create_completion_request(
227        &self,
228        completion_request: CompletionRequest,
229    ) -> Result<Value, CompletionError> {
230        let mut messages = Vec::new();
231
232        // Add preamble as user message if available
233        if let Some(preamble) = &completion_request.preamble {
234            messages.push(serde_json::json!({
235                "role": "user",
236                "content": preamble.to_string()
237            }));
238        }
239
240        // Add docs
241        if let Some(Message::User { content }) = completion_request.normalized_documents() {
242            let text = content
243                .into_iter()
244                .filter_map(|doc| match doc {
245                    UserContent::Document(doc) => Some(doc.data),
246                    UserContent::Text(text) => Some(text.text),
247
248                    // This should always be `Document`
249                    _ => None,
250                })
251                .collect::<Vec<_>>()
252                .join("\n");
253
254            messages.push(serde_json::json!({
255                "role": "user",
256                "content": text
257            }));
258        }
259
260        // Add chat history
261        for msg in completion_request.chat_history {
262            let (role, content) = match msg {
263                Message::User { content } => {
264                    let text = content
265                        .iter()
266                        .map(|c| match c {
267                            UserContent::Text(text) => &text.text,
268                            _ => "",
269                        })
270                        .collect::<Vec<_>>()
271                        .join("\n");
272                    ("user", text)
273                }
274                Message::Assistant { content } => {
275                    let text = content
276                        .iter()
277                        .map(|c| match c {
278                            AssistantContent::Text(text) => &text.text,
279                            _ => "",
280                        })
281                        .collect::<Vec<_>>()
282                        .join("\n");
283                    ("assistant", text)
284                }
285            };
286            messages.push(serde_json::json!({
287                "role": role,
288                "content": content
289            }));
290        }
291
292        let request = serde_json::json!({
293            "model": self.model,
294            "messages": messages,
295            "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
296            "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
297            "stream": false
298        });
299
300        Ok(request)
301    }
302}
303
304impl completion::CompletionModel for CompletionModel {
305    type Response = CompletionResponse;
306    type StreamingResponse = openai::StreamingCompletionResponse;
307
308    #[cfg_attr(feature = "worker", worker::send)]
309    async fn completion(
310        &self,
311        completion_request: CompletionRequest,
312    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
313        if !completion_request.tools.is_empty() {
314            tracing::warn!(target: "rig",
315                "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
316                completion_request.tools.len()
317            );
318        }
319
320        let mira_request = self.create_completion_request(completion_request)?;
321
322        let response = self
323            .client
324            .client
325            .post(format!("{}/v1/chat/completions", self.client.base_url))
326            .headers(self.client.headers.clone())
327            .json(&mira_request)
328            .send()
329            .await
330            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
331
332        if !response.status().is_success() {
333            let status = response.status().as_u16();
334            let error_text = response.text().await.unwrap_or_default();
335            return Err(CompletionError::ProviderError(format!(
336                "API error: {status} - {error_text}"
337            )));
338        }
339
340        let response: CompletionResponse = response
341            .json()
342            .await
343            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
344
345        response.try_into()
346    }
347
348    #[cfg_attr(feature = "worker", worker::send)]
349    async fn stream(
350        &self,
351        completion_request: CompletionRequest,
352    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
353        let mut request = self.create_completion_request(completion_request)?;
354
355        request = merge(request, json!({"stream": true}));
356
357        let builder = self
358            .client
359            .client
360            .post(format!("{}/v1/chat/completions", self.client.base_url))
361            .headers(self.client.headers.clone())
362            .json(&request);
363
364        send_compatible_streaming_request(builder).await
365    }
366}
367
368impl From<ApiErrorResponse> for CompletionError {
369    fn from(err: ApiErrorResponse) -> Self {
370        CompletionError::ProviderError(err.message)
371    }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375    type Error = CompletionError;
376
377    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378        let content = match &response {
379            CompletionResponse::Structured { choices, .. } => {
380                let choice = choices.first().ok_or_else(|| {
381                    CompletionError::ResponseError("Response contained no choices".to_owned())
382                })?;
383
384                // Convert RawMessage to message::Message
385                let message = message::Message::try_from(choice.message.clone())?;
386
387                match message {
388                    Message::Assistant { content } => {
389                        if content.is_empty() {
390                            return Err(CompletionError::ResponseError(
391                                "Response contained empty content".to_owned(),
392                            ));
393                        }
394
395                        // Log warning for unsupported content types
396                        for c in content.iter() {
397                            if !matches!(c, AssistantContent::Text(_)) {
398                                tracing::warn!(target: "rig",
399                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
400                                );
401                            }
402                        }
403
404                        content.iter().map(|c| {
405                            match c {
406                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
407                                other => Err(CompletionError::ResponseError(
408                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
409                                ))
410                            }
411                        }).collect::<Result<Vec<_>, _>>()?
412                    }
413                    Message::User { .. } => {
414                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
415                        return Err(CompletionError::ResponseError(
416                            "Received user message in response where assistant message was expected".to_owned()
417                        ));
418                    }
419                }
420            }
421            CompletionResponse::Simple(text) => {
422                vec![completion::AssistantContent::text(text)]
423            }
424        };
425
426        let choice = OneOrMany::many(content).map_err(|_| {
427            CompletionError::ResponseError(
428                "Response contained no message or tool call (empty)".to_owned(),
429            )
430        })?;
431
432        Ok(completion::CompletionResponse {
433            choice,
434            raw_response: response,
435        })
436    }
437}
438
439#[derive(Clone, Debug, Deserialize)]
440pub struct Usage {
441    pub prompt_tokens: usize,
442    pub total_tokens: usize,
443}
444
445impl std::fmt::Display for Usage {
446    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447        write!(
448            f,
449            "Prompt tokens: {} Total tokens: {}",
450            self.prompt_tokens, self.total_tokens
451        )
452    }
453}
454
455impl From<Message> for serde_json::Value {
456    fn from(msg: Message) -> Self {
457        match msg {
458            Message::User { content } => {
459                let text = content
460                    .iter()
461                    .map(|c| match c {
462                        UserContent::Text(text) => &text.text,
463                        _ => "",
464                    })
465                    .collect::<Vec<_>>()
466                    .join("\n");
467                serde_json::json!({
468                    "role": "user",
469                    "content": text
470                })
471            }
472            Message::Assistant { content } => {
473                let text = content
474                    .iter()
475                    .map(|c| match c {
476                        AssistantContent::Text(text) => &text.text,
477                        _ => "",
478                    })
479                    .collect::<Vec<_>>()
480                    .join("\n");
481                serde_json::json!({
482                    "role": "assistant",
483                    "content": text
484                })
485            }
486        }
487    }
488}
489
490impl TryFrom<serde_json::Value> for Message {
491    type Error = CompletionError;
492
493    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
494        let role = value["role"].as_str().ok_or_else(|| {
495            CompletionError::ResponseError("Message missing role field".to_owned())
496        })?;
497
498        // Handle both string and array content formats
499        let content = match value.get("content") {
500            Some(content) => match content {
501                serde_json::Value::String(s) => s.clone(),
502                serde_json::Value::Array(arr) => arr
503                    .iter()
504                    .filter_map(|c| {
505                        c.get("text")
506                            .and_then(|t| t.as_str())
507                            .map(|text| text.to_string())
508                    })
509                    .collect::<Vec<_>>()
510                    .join("\n"),
511                _ => {
512                    return Err(CompletionError::ResponseError(
513                        "Message content must be string or array".to_owned(),
514                    ))
515                }
516            },
517            None => {
518                return Err(CompletionError::ResponseError(
519                    "Message missing content field".to_owned(),
520                ))
521            }
522        };
523
524        match role {
525            "user" => Ok(Message::User {
526                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
527            }),
528            "assistant" => Ok(Message::Assistant {
529                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
530            }),
531            _ => Err(CompletionError::ResponseError(format!(
532                "Unsupported message role: {role}"
533            ))),
534        }
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Extract the message of a `CompletionError::ResponseError`, failing the test otherwise.
    fn response_error_message(err: CompletionError) -> String {
        match err {
            CompletionError::ResponseError(msg) => msg,
            other => panic!("expected ResponseError, got {other:?}"),
        }
    }

    /// Build a structured response whose only choice has the given role and content.
    fn structured_response(role: &str, content: &str) -> CompletionResponse {
        CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: role.to_string(),
                    content: content.to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        }
    }

    #[test]
    fn test_deserialize_message() {
        // Test string content format
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        // Test array content format
        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        match assistant_message {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        match assistant_message_array {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    #[test]
    fn test_deserialize_message_errors() {
        let missing_role = Message::try_from(json!({"content": "hi"})).unwrap_err();
        assert_eq!(
            response_error_message(missing_role),
            "Message missing role field"
        );

        let missing_content = Message::try_from(json!({"role": "user"})).unwrap_err();
        assert_eq!(
            response_error_message(missing_content),
            "Message missing content field"
        );

        let bad_content = Message::try_from(json!({"role": "user", "content": 42})).unwrap_err();
        assert_eq!(
            response_error_message(bad_content),
            "Message content must be string or array"
        );

        let bad_role =
            Message::try_from(json!({"role": "system", "content": "hi"})).unwrap_err();
        assert_eq!(
            response_error_message(bad_role),
            "Unsupported message role: system"
        );
    }

    #[test]
    fn test_raw_message_unsupported_role() {
        let raw = RawMessage {
            role: "tool".to_string(),
            content: "x".to_string(),
        };
        let err = message::Message::try_from(raw).unwrap_err();
        assert_eq!(response_error_message(err), "Unsupported message role: tool");
    }

    #[test]
    fn test_message_conversion() {
        // Round-trip our Message type through Mira's JSON wire format.
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value = serde_json::Value::from(original_message.clone());
        assert_eq!(mira_value, json!({"role": "user", "content": "Hello"}));

        let round_tripped = Message::try_from(mira_value).unwrap();
        assert_eq!(original_message, round_tripped);

        let assistant = message::Message::Assistant {
            content: OneOrMany::one(AssistantContent::Text(message::Text {
                text: "Hi".to_string(),
            })),
        };
        let assistant_value = serde_json::Value::from(assistant.clone());
        assert_eq!(assistant_value, json!({"role": "assistant", "content": "Hi"}));
        assert_eq!(Message::try_from(assistant_value).unwrap(), assistant);
    }

    #[test]
    fn test_completion_response_conversion() {
        let mira_response = structured_response("assistant", "Test response");

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    #[test]
    fn test_structured_response_deserialization_with_missing_optionals() {
        let raw = json!({
            "id": "resp_456",
            "object": "chat.completion",
            "created": 1,
            "model": "m",
            "choices": [{"message": {"role": "assistant", "content": "Hi"}}]
        });

        let mira_response: CompletionResponse = serde_json::from_value(raw).unwrap();
        match &mira_response {
            CompletionResponse::Structured { choices, usage, .. } => {
                assert_eq!(choices.len(), 1);
                assert!(choices[0].finish_reason.is_none());
                assert!(choices[0].index.is_none());
                assert!(usage.is_none());
            }
            CompletionResponse::Simple(_) => panic!("expected a structured response"),
        }
    }

    #[test]
    fn test_simple_response_conversion() {
        let mira_response: CompletionResponse =
            serde_json::from_value(json!("Plain text answer")).unwrap();
        assert!(matches!(mira_response, CompletionResponse::Simple(_)));

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();
        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Plain text answer")
        );
    }

    #[test]
    fn test_empty_choices_is_error() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_789".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: "m".to_string(),
            choices: vec![],
            usage: None,
        };

        let result: Result<completion::CompletionResponse<CompletionResponse>, _> =
            mira_response.try_into();
        match result {
            Err(err) => assert_eq!(response_error_message(err), "Response contained no choices"),
            Ok(_) => panic!("expected an error for a response without choices"),
        }
    }

    #[test]
    fn test_user_role_response_is_error() {
        let result: Result<completion::CompletionResponse<CompletionResponse>, _> =
            structured_response("user", "echo").try_into();
        match result {
            Err(err) => assert_eq!(
                response_error_message(err),
                "Received user message in response where assistant message was expected"
            ),
            Ok(_) => panic!("expected an error for a user-role response"),
        }
    }

    #[test]
    fn test_invalid_api_key_is_rejected() {
        // A newline is not a valid HTTP header value byte.
        assert!(matches!(
            Client::new("bad\nkey"),
            Err(MiraError::InvalidApiKey)
        ));
    }

    #[test]
    fn test_usage_display() {
        let usage = Usage {
            prompt_tokens: 10,
            total_tokens: 20,
        };
        assert_eq!(usage.to_string(), "Prompt tokens: 10 Total tokens: 20");
    }
}