rig/providers/
mira.rs

//! Mira API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::providers::mira;
//!
//! let client = mira::Client::new("YOUR_API_KEY").expect("failed to build Mira client");
//! ```
10use crate::client::{
11    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12    ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
29#[derive(Debug, Default, Clone, Copy)]
30pub struct MiraExt;
31#[derive(Debug, Default, Clone, Copy)]
32pub struct MiraBuilder;
33
34type MiraApiKey = BearerAuth;
35
36impl Provider for MiraExt {
37    type Builder = MiraBuilder;
38
39    const VERIFY_PATH: &'static str = "/user-credits";
40
41    fn build<H>(
42        _: &crate::client::ClientBuilder<
43            Self::Builder,
44            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
45            H,
46        >,
47    ) -> http_client::Result<Self> {
48        Ok(Self)
49    }
50}
51
52impl<H> Capabilities<H> for MiraExt {
53    type Completion = Capable<CompletionModel<H>>;
54    type Embeddings = Nothing;
55    type Transcription = Nothing;
56
57    #[cfg(feature = "image")]
58    type ImageGeneration = Nothing;
59
60    #[cfg(feature = "audio")]
61    type AudioGeneration = Nothing;
62}
63
64impl DebugExt for MiraExt {}
65
66impl ProviderBuilder for MiraBuilder {
67    type Output = MiraExt;
68    type ApiKey = MiraApiKey;
69
70    const BASE_URL: &'static str = MIRA_API_BASE_URL;
71}
72
73pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
74pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
75
76#[derive(Debug, Error)]
77pub enum MiraError {
78    #[error("Invalid API key")]
79    InvalidApiKey,
80    #[error("API error: {0}")]
81    ApiError(u16),
82    #[error("Request error: {0}")]
83    RequestError(#[from] http_client::Error),
84    #[error("UTF-8 error: {0}")]
85    Utf8Error(#[from] FromUtf8Error),
86    #[error("JSON error: {0}")]
87    JsonError(#[from] serde_json::Error),
88}
89
90#[derive(Debug, Deserialize)]
91struct ApiErrorResponse {
92    message: String,
93}
94
95#[derive(Debug, Deserialize, Clone, Serialize)]
96pub struct RawMessage {
97    pub role: String,
98    pub content: String,
99}
100
101const MIRA_API_BASE_URL: &str = "https://api.mira.network";
102
103impl TryFrom<RawMessage> for message::Message {
104    type Error = CompletionError;
105
106    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
107        match raw.role.as_str() {
108            "user" => Ok(message::Message::User {
109                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
110            }),
111            "assistant" => Ok(message::Message::Assistant {
112                id: None,
113                content: OneOrMany::one(AssistantContent::Text(message::Text {
114                    text: raw.content,
115                })),
116            }),
117            _ => Err(CompletionError::ResponseError(format!(
118                "Unsupported message role: {}",
119                raw.role
120            ))),
121        }
122    }
123}
124
125#[derive(Debug, Deserialize, Serialize)]
126#[serde(untagged)]
127pub enum CompletionResponse {
128    Structured {
129        id: String,
130        object: String,
131        created: u64,
132        model: String,
133        choices: Vec<ChatChoice>,
134        #[serde(skip_serializing_if = "Option::is_none")]
135        usage: Option<Usage>,
136    },
137    Simple(String),
138}
139
140#[derive(Debug, Deserialize, Serialize)]
141pub struct ChatChoice {
142    pub message: RawMessage,
143    #[serde(default)]
144    pub finish_reason: Option<String>,
145    #[serde(default)]
146    pub index: Option<usize>,
147}
148
149#[derive(Debug, Deserialize, Serialize)]
150struct ModelsResponse {
151    data: Vec<ModelInfo>,
152}
153
154#[derive(Debug, Deserialize, Serialize)]
155struct ModelInfo {
156    id: String,
157}
158
159impl<T> Client<T>
160where
161    T: HttpClientExt + 'static,
162{
163    /// List available models
164    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
165        let req = self.get("/v1/models").and_then(|req| {
166            req.body(http_client::NoBody)
167                .map_err(http_client::Error::Protocol)
168        })?;
169
170        let response = self.send(req).await?;
171
172        let status = response.status();
173
174        if !status.is_success() {
175            // Log the error text but don't store it in an unused variable
176            let error_text = http_client::text(response).await.unwrap_or_default();
177            tracing::error!("Error response: {}", error_text);
178            return Err(MiraError::ApiError(status.as_u16()));
179        }
180
181        let response_text = http_client::text(response).await?;
182
183        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
184            tracing::error!("Failed to parse response: {}", e);
185            MiraError::JsonError(e)
186        })?;
187
188        Ok(models.data.into_iter().map(|model| model.id).collect())
189    }
190}
191
192impl ProviderClient for Client {
193    type Input = String;
194
195    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
196    /// Panics if the environment variable is not set.
197    fn from_env() -> Self {
198        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
199        Self::new(&api_key).unwrap()
200    }
201
202    fn from_val(input: Self::Input) -> Self {
203        Self::new(&input).unwrap()
204    }
205}
206
207#[derive(Debug, Serialize, Deserialize)]
208pub(super) struct MiraCompletionRequest {
209    model: String,
210    pub messages: Vec<RawMessage>,
211    #[serde(flatten, skip_serializing_if = "Option::is_none")]
212    temperature: Option<f64>,
213    #[serde(flatten, skip_serializing_if = "Option::is_none")]
214    max_tokens: Option<u64>,
215    pub stream: bool,
216}
217
218impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
219    type Error = CompletionError;
220
221    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
222        let mut messages = Vec::new();
223
224        if let Some(content) = &req.preamble {
225            messages.push(RawMessage {
226                role: "user".to_string(),
227                content: content.to_string(),
228            });
229        }
230
231        if let Some(Message::User { content }) = req.normalized_documents() {
232            let text = content
233                .into_iter()
234                .filter_map(|doc| match doc {
235                    UserContent::Document(Document {
236                        data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
237                        ..
238                    }) => Some(data),
239                    UserContent::Text(text) => Some(text.text),
240
241                    // This should always be `Document`
242                    _ => None,
243                })
244                .collect::<Vec<_>>()
245                .join("\n");
246
247            messages.push(RawMessage {
248                role: "user".to_string(),
249                content: text,
250            });
251        }
252
253        for msg in req.chat_history {
254            let (role, content) = match msg {
255                Message::User { content } => {
256                    let text = content
257                        .iter()
258                        .map(|c| match c {
259                            UserContent::Text(text) => &text.text,
260                            _ => "",
261                        })
262                        .collect::<Vec<_>>()
263                        .join("\n");
264                    ("user", text)
265                }
266                Message::Assistant { content, .. } => {
267                    let text = content
268                        .iter()
269                        .map(|c| match c {
270                            AssistantContent::Text(text) => &text.text,
271                            _ => "",
272                        })
273                        .collect::<Vec<_>>()
274                        .join("\n");
275                    ("assistant", text)
276                }
277            };
278            messages.push(RawMessage {
279                role: role.to_string(),
280                content,
281            });
282        }
283
284        Ok(Self {
285            model: model.to_string(),
286            messages,
287            temperature: req.temperature,
288            max_tokens: req.max_tokens,
289            stream: false,
290        })
291    }
292}
293
294#[derive(Clone)]
295pub struct CompletionModel<T = reqwest::Client> {
296    client: Client<T>,
297    /// Name of the model
298    pub model: String,
299}
300
301impl<T> CompletionModel<T> {
302    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
303        Self {
304            client,
305            model: model.into(),
306        }
307    }
308}
309
310impl<T> completion::CompletionModel for CompletionModel<T>
311where
312    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
313{
314    type Response = CompletionResponse;
315    type StreamingResponse = openai::StreamingCompletionResponse;
316
317    type Client = Client<T>;
318
319    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
320        Self::new(client.clone(), model)
321    }
322
323    #[cfg_attr(feature = "worker", worker::send)]
324    async fn completion(
325        &self,
326        completion_request: CompletionRequest,
327    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
328        let span = if tracing::Span::current().is_disabled() {
329            info_span!(
330                target: "rig::completions",
331                "chat",
332                gen_ai.operation.name = "chat",
333                gen_ai.provider.name = "mira",
334                gen_ai.request.model = self.model,
335                gen_ai.system_instructions = tracing::field::Empty,
336                gen_ai.response.id = tracing::field::Empty,
337                gen_ai.response.model = tracing::field::Empty,
338                gen_ai.usage.output_tokens = tracing::field::Empty,
339                gen_ai.usage.input_tokens = tracing::field::Empty,
340            )
341        } else {
342            tracing::Span::current()
343        };
344
345        span.record("gen_ai.system_instructions", &completion_request.preamble);
346
347        if !completion_request.tools.is_empty() {
348            tracing::warn!(target: "rig::completions",
349                "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
350                len = completion_request.tools.len()
351            );
352        }
353
354        if completion_request.tool_choice.is_some() {
355            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
356        }
357
358        if completion_request.additional_params.is_some() {
359            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
360        }
361
362        let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
363
364        if tracing::enabled!(tracing::Level::TRACE) {
365            tracing::trace!(target: "rig::completions",
366                "Mira completion request: {}",
367                serde_json::to_string_pretty(&request)?
368            );
369        }
370
371        let body = serde_json::to_vec(&request)?;
372
373        let req = self
374            .client
375            .post("/v1/chat/completions")?
376            .body(body)
377            .map_err(http_client::Error::from)?;
378
379        let async_block = async move {
380            let response = self
381                .client
382                .send::<_, bytes::Bytes>(req)
383                .await
384                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
385
386            let status = response.status();
387            let response_body = response.into_body().into_future().await?.to_vec();
388
389            if !status.is_success() {
390                let status = status.as_u16();
391                let error_text = String::from_utf8_lossy(&response_body).to_string();
392                return Err(CompletionError::ProviderError(format!(
393                    "API error: {status} - {error_text}"
394                )));
395            }
396
397            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
398
399            if tracing::enabled!(tracing::Level::TRACE) {
400                tracing::trace!(target: "rig::completions",
401                    "Mira completion response: {}",
402                    serde_json::to_string_pretty(&response)?
403                );
404            }
405
406            if let CompletionResponse::Structured {
407                id, model, usage, ..
408            } = &response
409            {
410                let span = tracing::Span::current();
411                span.record("gen_ai.response.model_name", model);
412                span.record("gen_ai.response.id", id);
413                if let Some(usage) = usage {
414                    span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
415                    span.record(
416                        "gen_ai.usage.output_tokens",
417                        usage.total_tokens - usage.prompt_tokens,
418                    );
419                }
420            }
421
422            response.try_into()
423        };
424
425        async_block.instrument(span).await
426    }
427
428    #[cfg_attr(feature = "worker", worker::send)]
429    async fn stream(
430        &self,
431        completion_request: CompletionRequest,
432    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
433        let span = if tracing::Span::current().is_disabled() {
434            info_span!(
435                target: "rig::completions",
436                "chat_streaming",
437                gen_ai.operation.name = "chat_streaming",
438                gen_ai.provider.name = "mira",
439                gen_ai.request.model = self.model,
440                gen_ai.system_instructions = tracing::field::Empty,
441                gen_ai.response.id = tracing::field::Empty,
442                gen_ai.response.model = tracing::field::Empty,
443                gen_ai.usage.output_tokens = tracing::field::Empty,
444                gen_ai.usage.input_tokens = tracing::field::Empty,
445            )
446        } else {
447            tracing::Span::current()
448        };
449
450        span.record("gen_ai.system_instructions", &completion_request.preamble);
451
452        if !completion_request.tools.is_empty() {
453            tracing::warn!(target: "rig::completions",
454                "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
455                len = completion_request.tools.len()
456            );
457        }
458
459        if completion_request.tool_choice.is_some() {
460            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
461        }
462
463        if completion_request.additional_params.is_some() {
464            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
465        }
466        let mut request =
467            MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
468        request.stream = true;
469
470        if tracing::enabled!(tracing::Level::TRACE) {
471            tracing::trace!(target: "rig::completions",
472                "Mira completion request: {}",
473                serde_json::to_string_pretty(&request)?
474            );
475        }
476
477        let body = serde_json::to_vec(&request)?;
478
479        let req = self
480            .client
481            .post("/v1/chat/completions")?
482            .body(body)
483            .map_err(http_client::Error::from)?;
484
485        send_compatible_streaming_request(self.client.clone(), req)
486            .instrument(span)
487            .await
488    }
489}
490
491impl From<ApiErrorResponse> for CompletionError {
492    fn from(err: ApiErrorResponse) -> Self {
493        CompletionError::ProviderError(err.message)
494    }
495}
496
497impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
498    type Error = CompletionError;
499
500    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
501        let (content, usage) = match &response {
502            CompletionResponse::Structured { choices, usage, .. } => {
503                let choice = choices.first().ok_or_else(|| {
504                    CompletionError::ResponseError("Response contained no choices".to_owned())
505                })?;
506
507                let usage = usage
508                    .as_ref()
509                    .map(|usage| completion::Usage {
510                        input_tokens: usage.prompt_tokens as u64,
511                        output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
512                        total_tokens: usage.total_tokens as u64,
513                    })
514                    .unwrap_or_default();
515
516                // Convert RawMessage to message::Message
517                let message = message::Message::try_from(choice.message.clone())?;
518
519                let content = match message {
520                    Message::Assistant { content, .. } => {
521                        if content.is_empty() {
522                            return Err(CompletionError::ResponseError(
523                                "Response contained empty content".to_owned(),
524                            ));
525                        }
526
527                        // Log warning for unsupported content types
528                        for c in content.iter() {
529                            if !matches!(c, AssistantContent::Text(_)) {
530                                tracing::warn!(target: "rig",
531                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
532                                );
533                            }
534                        }
535
536                        content.iter().map(|c| {
537                            match c {
538                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
539                                other => Err(CompletionError::ResponseError(
540                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
541                                ))
542                            }
543                        }).collect::<Result<Vec<_>, _>>()?
544                    }
545                    Message::User { .. } => {
546                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
547                        return Err(CompletionError::ResponseError(
548                            "Received user message in response where assistant message was expected".to_owned()
549                        ));
550                    }
551                };
552
553                (content, usage)
554            }
555            CompletionResponse::Simple(text) => (
556                vec![completion::AssistantContent::text(text)],
557                completion::Usage::new(),
558            ),
559        };
560
561        let choice = OneOrMany::many(content).map_err(|_| {
562            CompletionError::ResponseError(
563                "Response contained no message or tool call (empty)".to_owned(),
564            )
565        })?;
566
567        Ok(completion::CompletionResponse {
568            choice,
569            usage,
570            raw_response: response,
571        })
572    }
573}
574
575#[derive(Clone, Debug, Deserialize, Serialize)]
576pub struct Usage {
577    pub prompt_tokens: usize,
578    pub total_tokens: usize,
579}
580
581impl std::fmt::Display for Usage {
582    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583        write!(
584            f,
585            "Prompt tokens: {} Total tokens: {}",
586            self.prompt_tokens, self.total_tokens
587        )
588    }
589}
590
591impl From<Message> for serde_json::Value {
592    fn from(msg: Message) -> Self {
593        match msg {
594            Message::User { content } => {
595                let text = content
596                    .iter()
597                    .map(|c| match c {
598                        UserContent::Text(text) => &text.text,
599                        _ => "",
600                    })
601                    .collect::<Vec<_>>()
602                    .join("\n");
603                serde_json::json!({
604                    "role": "user",
605                    "content": text
606                })
607            }
608            Message::Assistant { content, .. } => {
609                let text = content
610                    .iter()
611                    .map(|c| match c {
612                        AssistantContent::Text(text) => &text.text,
613                        _ => "",
614                    })
615                    .collect::<Vec<_>>()
616                    .join("\n");
617                serde_json::json!({
618                    "role": "assistant",
619                    "content": text
620                })
621            }
622        }
623    }
624}
625
626impl TryFrom<serde_json::Value> for Message {
627    type Error = CompletionError;
628
629    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
630        let role = value["role"].as_str().ok_or_else(|| {
631            CompletionError::ResponseError("Message missing role field".to_owned())
632        })?;
633
634        // Handle both string and array content formats
635        let content = match value.get("content") {
636            Some(content) => match content {
637                serde_json::Value::String(s) => s.clone(),
638                serde_json::Value::Array(arr) => arr
639                    .iter()
640                    .filter_map(|c| {
641                        c.get("text")
642                            .and_then(|t| t.as_str())
643                            .map(|text| text.to_string())
644                    })
645                    .collect::<Vec<_>>()
646                    .join("\n"),
647                _ => {
648                    return Err(CompletionError::ResponseError(
649                        "Message content must be string or array".to_owned(),
650                    ));
651                }
652            },
653            None => {
654                return Err(CompletionError::ResponseError(
655                    "Message missing content field".to_owned(),
656                ));
657            }
658        };
659
660        match role {
661            "user" => Ok(Message::User {
662                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
663            }),
664            "assistant" => Ok(Message::Assistant {
665                id: None,
666                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
667            }),
668            _ => Err(CompletionError::ResponseError(format!(
669                "Unsupported message role: {role}"
670            ))),
671        }
672    }
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    const GREETING: &str = "Hello there, how may I assist you today?";
    const QUESTION: &str = "What can you help me with?";

    /// Builds a `message::Text` part from a string slice.
    fn text_part(text: &str) -> message::Text {
        message::Text {
            text: text.to_string(),
        }
    }

    /// Returns the first content part of an assistant message; panics on any other role.
    fn first_assistant_part(message: Message) -> AssistantContent {
        if let Message::Assistant { content, .. } = message {
            content.first()
        } else {
            panic!("Expected assistant message")
        }
    }

    /// Returns the first content part of a user message; panics on any other role.
    fn first_user_part(message: Message) -> UserContent {
        if let Message::User { content } = message {
            content.first()
        } else {
            panic!("Expected user message")
        }
    }

    #[test]
    fn test_deserialize_message() {
        // String content format.
        let assistant_message =
            Message::try_from(json!({ "role": "assistant", "content": GREETING })).unwrap();
        let user_message =
            Message::try_from(json!({ "role": "user", "content": QUESTION })).unwrap();

        // Array content format.
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{ "type": "text", "text": GREETING }]
        }))
        .unwrap();

        assert_eq!(
            first_assistant_part(assistant_message),
            AssistantContent::Text(text_part(GREETING))
        );
        assert_eq!(
            first_user_part(user_message),
            UserContent::Text(text_part(QUESTION))
        );
        assert_eq!(
            first_assistant_part(assistant_message_array),
            AssistantContent::Text(text_part(GREETING))
        );
    }

    #[test]
    fn test_message_conversion() {
        // Round trip: rig Message -> Mira JSON -> rig Message.
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value = serde_json::Value::from(original_message.clone());
        let converted_message: Message = mira_value.try_into().unwrap();

        assert_eq!(original_message, converted_message);
    }

    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };

        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response)
                .unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}