Skip to main content

rig/providers/
mira.rs

//! Mira API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::providers::mira;
//!
//! let client = mira::Client::new("YOUR_API_KEY").expect("failed to build Mira client");
//! ```
10use crate::client::{
11    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12    ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
29#[derive(Debug, Default, Clone, Copy)]
30pub struct MiraExt;
31#[derive(Debug, Default, Clone, Copy)]
32pub struct MiraBuilder;
33
34type MiraApiKey = BearerAuth;
35
36impl Provider for MiraExt {
37    type Builder = MiraBuilder;
38
39    const VERIFY_PATH: &'static str = "/user-credits";
40
41    fn build<H>(
42        _: &crate::client::ClientBuilder<
43            Self::Builder,
44            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
45            H,
46        >,
47    ) -> http_client::Result<Self> {
48        Ok(Self)
49    }
50}
51
52impl<H> Capabilities<H> for MiraExt {
53    type Completion = Capable<CompletionModel<H>>;
54    type Embeddings = Nothing;
55    type Transcription = Nothing;
56    type ModelListing = Nothing;
57
58    #[cfg(feature = "image")]
59    type ImageGeneration = Nothing;
60
61    #[cfg(feature = "audio")]
62    type AudioGeneration = Nothing;
63}
64
65impl DebugExt for MiraExt {}
66
67impl ProviderBuilder for MiraBuilder {
68    type Output = MiraExt;
69    type ApiKey = MiraApiKey;
70
71    const BASE_URL: &'static str = MIRA_API_BASE_URL;
72}
73
74pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
75pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
76
77#[derive(Debug, Error)]
78pub enum MiraError {
79    #[error("Invalid API key")]
80    InvalidApiKey,
81    #[error("API error: {0}")]
82    ApiError(u16),
83    #[error("Request error: {0}")]
84    RequestError(#[from] http_client::Error),
85    #[error("UTF-8 error: {0}")]
86    Utf8Error(#[from] FromUtf8Error),
87    #[error("JSON error: {0}")]
88    JsonError(#[from] serde_json::Error),
89}
90
91#[derive(Debug, Deserialize)]
92struct ApiErrorResponse {
93    message: String,
94}
95
96#[derive(Debug, Deserialize, Clone, Serialize)]
97pub struct RawMessage {
98    pub role: String,
99    pub content: String,
100}
101
102const MIRA_API_BASE_URL: &str = "https://api.mira.network";
103
104impl TryFrom<RawMessage> for message::Message {
105    type Error = CompletionError;
106
107    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
108        match raw.role.as_str() {
109            "user" => Ok(message::Message::User {
110                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
111            }),
112            "assistant" => Ok(message::Message::Assistant {
113                id: None,
114                content: OneOrMany::one(AssistantContent::Text(message::Text {
115                    text: raw.content,
116                })),
117            }),
118            _ => Err(CompletionError::ResponseError(format!(
119                "Unsupported message role: {}",
120                raw.role
121            ))),
122        }
123    }
124}
125
126#[derive(Debug, Deserialize, Serialize)]
127#[serde(untagged)]
128pub enum CompletionResponse {
129    Structured {
130        id: String,
131        object: String,
132        created: u64,
133        model: String,
134        choices: Vec<ChatChoice>,
135        #[serde(skip_serializing_if = "Option::is_none")]
136        usage: Option<Usage>,
137    },
138    Simple(String),
139}
140
141#[derive(Debug, Deserialize, Serialize)]
142pub struct ChatChoice {
143    pub message: RawMessage,
144    #[serde(default)]
145    pub finish_reason: Option<String>,
146    #[serde(default)]
147    pub index: Option<usize>,
148}
149
150#[derive(Debug, Deserialize, Serialize)]
151struct ModelsResponse {
152    data: Vec<ModelInfo>,
153}
154
155#[derive(Debug, Deserialize, Serialize)]
156struct ModelInfo {
157    id: String,
158}
159
160impl<T> Client<T>
161where
162    T: HttpClientExt + 'static,
163{
164    /// List available models
165    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
166        let req = self.get("/v1/models").and_then(|req| {
167            req.body(http_client::NoBody)
168                .map_err(http_client::Error::Protocol)
169        })?;
170
171        let response = self.send(req).await?;
172
173        let status = response.status();
174
175        if !status.is_success() {
176            // Log the error text but don't store it in an unused variable
177            let error_text = http_client::text(response).await.unwrap_or_default();
178            tracing::error!("Error response: {}", error_text);
179            return Err(MiraError::ApiError(status.as_u16()));
180        }
181
182        let response_text = http_client::text(response).await?;
183
184        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
185            tracing::error!("Failed to parse response: {}", e);
186            MiraError::JsonError(e)
187        })?;
188
189        Ok(models.data.into_iter().map(|model| model.id).collect())
190    }
191}
192
193impl ProviderClient for Client {
194    type Input = String;
195
196    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
197    /// Panics if the environment variable is not set.
198    fn from_env() -> Self {
199        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
200        Self::new(&api_key).unwrap()
201    }
202
203    fn from_val(input: Self::Input) -> Self {
204        Self::new(&input).unwrap()
205    }
206}
207
208#[derive(Debug, Serialize, Deserialize)]
209pub(super) struct MiraCompletionRequest {
210    model: String,
211    pub messages: Vec<RawMessage>,
212    #[serde(skip_serializing_if = "Option::is_none")]
213    temperature: Option<f64>,
214    #[serde(skip_serializing_if = "Option::is_none")]
215    max_tokens: Option<u64>,
216    pub stream: bool,
217}
218
219impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
220    type Error = CompletionError;
221
222    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
223        if req.output_schema.is_some() {
224            tracing::warn!("Structured outputs currently not supported for Mira");
225        }
226        let model = req.model.clone().unwrap_or_else(|| model.to_string());
227        let mut messages = Vec::new();
228
229        if let Some(content) = &req.preamble {
230            messages.push(RawMessage {
231                role: "user".to_string(),
232                content: content.to_string(),
233            });
234        }
235
236        if let Some(Message::User { content }) = req.normalized_documents() {
237            let text = content
238                .into_iter()
239                .filter_map(|doc| match doc {
240                    UserContent::Document(Document {
241                        data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
242                        ..
243                    }) => Some(data),
244                    UserContent::Text(text) => Some(text.text),
245
246                    // This should always be `Document`
247                    _ => None,
248                })
249                .collect::<Vec<_>>()
250                .join("\n");
251
252            messages.push(RawMessage {
253                role: "user".to_string(),
254                content: text,
255            });
256        }
257
258        for msg in req.chat_history {
259            let (role, content) = match msg {
260                Message::User { content } => {
261                    let text = content
262                        .iter()
263                        .map(|c| match c {
264                            UserContent::Text(text) => &text.text,
265                            _ => "",
266                        })
267                        .collect::<Vec<_>>()
268                        .join("\n");
269                    ("user", text)
270                }
271                Message::Assistant { content, .. } => {
272                    let text = content
273                        .iter()
274                        .map(|c| match c {
275                            AssistantContent::Text(text) => &text.text,
276                            _ => "",
277                        })
278                        .collect::<Vec<_>>()
279                        .join("\n");
280                    ("assistant", text)
281                }
282            };
283            messages.push(RawMessage {
284                role: role.to_string(),
285                content,
286            });
287        }
288
289        Ok(Self {
290            model: model.to_string(),
291            messages,
292            temperature: req.temperature,
293            max_tokens: req.max_tokens,
294            stream: false,
295        })
296    }
297}
298
299#[derive(Clone)]
300pub struct CompletionModel<T = reqwest::Client> {
301    client: Client<T>,
302    /// Name of the model
303    pub model: String,
304}
305
306impl<T> CompletionModel<T> {
307    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
308        Self {
309            client,
310            model: model.into(),
311        }
312    }
313}
314
315impl<T> completion::CompletionModel for CompletionModel<T>
316where
317    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
318{
319    type Response = CompletionResponse;
320    type StreamingResponse = openai::StreamingCompletionResponse;
321
322    type Client = Client<T>;
323
324    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
325        Self::new(client.clone(), model)
326    }
327
328    async fn completion(
329        &self,
330        completion_request: CompletionRequest,
331    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
332        let span = if tracing::Span::current().is_disabled() {
333            info_span!(
334                target: "rig::completions",
335                "chat",
336                gen_ai.operation.name = "chat",
337                gen_ai.provider.name = "mira",
338                gen_ai.request.model = self.model,
339                gen_ai.system_instructions = tracing::field::Empty,
340                gen_ai.response.id = tracing::field::Empty,
341                gen_ai.response.model = tracing::field::Empty,
342                gen_ai.usage.output_tokens = tracing::field::Empty,
343                gen_ai.usage.input_tokens = tracing::field::Empty,
344            )
345        } else {
346            tracing::Span::current()
347        };
348
349        span.record("gen_ai.system_instructions", &completion_request.preamble);
350
351        if !completion_request.tools.is_empty() {
352            tracing::warn!(target: "rig::completions",
353                "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
354                len = completion_request.tools.len()
355            );
356        }
357
358        if completion_request.tool_choice.is_some() {
359            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
360        }
361
362        if completion_request.additional_params.is_some() {
363            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
364        }
365
366        let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
367
368        if tracing::enabled!(tracing::Level::TRACE) {
369            tracing::trace!(target: "rig::completions",
370                "Mira completion request: {}",
371                serde_json::to_string_pretty(&request)?
372            );
373        }
374
375        let body = serde_json::to_vec(&request)?;
376
377        let req = self
378            .client
379            .post("/v1/chat/completions")?
380            .body(body)
381            .map_err(http_client::Error::from)?;
382
383        let async_block = async move {
384            let response = self
385                .client
386                .send::<_, bytes::Bytes>(req)
387                .await
388                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
389
390            let status = response.status();
391            let response_body = response.into_body().into_future().await?.to_vec();
392
393            if !status.is_success() {
394                let status = status.as_u16();
395                let error_text = String::from_utf8_lossy(&response_body).to_string();
396                return Err(CompletionError::ProviderError(format!(
397                    "API error: {status} - {error_text}"
398                )));
399            }
400
401            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
402
403            if tracing::enabled!(tracing::Level::TRACE) {
404                tracing::trace!(target: "rig::completions",
405                    "Mira completion response: {}",
406                    serde_json::to_string_pretty(&response)?
407                );
408            }
409
410            if let CompletionResponse::Structured {
411                id, model, usage, ..
412            } = &response
413            {
414                let span = tracing::Span::current();
415                span.record("gen_ai.response.model_name", model);
416                span.record("gen_ai.response.id", id);
417                if let Some(usage) = usage {
418                    span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
419                    span.record(
420                        "gen_ai.usage.output_tokens",
421                        usage.total_tokens - usage.prompt_tokens,
422                    );
423                }
424            }
425
426            response.try_into()
427        };
428
429        async_block.instrument(span).await
430    }
431
432    async fn stream(
433        &self,
434        completion_request: CompletionRequest,
435    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
436        let span = if tracing::Span::current().is_disabled() {
437            info_span!(
438                target: "rig::completions",
439                "chat_streaming",
440                gen_ai.operation.name = "chat_streaming",
441                gen_ai.provider.name = "mira",
442                gen_ai.request.model = self.model,
443                gen_ai.system_instructions = tracing::field::Empty,
444                gen_ai.response.id = tracing::field::Empty,
445                gen_ai.response.model = tracing::field::Empty,
446                gen_ai.usage.output_tokens = tracing::field::Empty,
447                gen_ai.usage.input_tokens = tracing::field::Empty,
448            )
449        } else {
450            tracing::Span::current()
451        };
452
453        span.record("gen_ai.system_instructions", &completion_request.preamble);
454
455        if !completion_request.tools.is_empty() {
456            tracing::warn!(target: "rig::completions",
457                "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
458                len = completion_request.tools.len()
459            );
460        }
461
462        if completion_request.tool_choice.is_some() {
463            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
464        }
465
466        if completion_request.additional_params.is_some() {
467            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
468        }
469        let mut request =
470            MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
471        request.stream = true;
472
473        if tracing::enabled!(tracing::Level::TRACE) {
474            tracing::trace!(target: "rig::completions",
475                "Mira completion request: {}",
476                serde_json::to_string_pretty(&request)?
477            );
478        }
479
480        let body = serde_json::to_vec(&request)?;
481
482        let req = self
483            .client
484            .post("/v1/chat/completions")?
485            .body(body)
486            .map_err(http_client::Error::from)?;
487
488        send_compatible_streaming_request(self.client.clone(), req)
489            .instrument(span)
490            .await
491    }
492}
493
494impl From<ApiErrorResponse> for CompletionError {
495    fn from(err: ApiErrorResponse) -> Self {
496        CompletionError::ProviderError(err.message)
497    }
498}
499
500impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
501    type Error = CompletionError;
502
503    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
504        let (content, usage) = match &response {
505            CompletionResponse::Structured { choices, usage, .. } => {
506                let choice = choices.first().ok_or_else(|| {
507                    CompletionError::ResponseError("Response contained no choices".to_owned())
508                })?;
509
510                let usage = usage
511                    .as_ref()
512                    .map(|usage| completion::Usage {
513                        input_tokens: usage.prompt_tokens as u64,
514                        output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
515                        total_tokens: usage.total_tokens as u64,
516                        cached_input_tokens: 0,
517                    })
518                    .unwrap_or_default();
519
520                // Convert RawMessage to message::Message
521                let message = message::Message::try_from(choice.message.clone())?;
522
523                let content = match message {
524                    Message::Assistant { content, .. } => {
525                        if content.is_empty() {
526                            return Err(CompletionError::ResponseError(
527                                "Response contained empty content".to_owned(),
528                            ));
529                        }
530
531                        // Log warning for unsupported content types
532                        for c in content.iter() {
533                            if !matches!(c, AssistantContent::Text(_)) {
534                                tracing::warn!(target: "rig",
535                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
536                                );
537                            }
538                        }
539
540                        content.iter().map(|c| {
541                            match c {
542                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
543                                other => Err(CompletionError::ResponseError(
544                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
545                                ))
546                            }
547                        }).collect::<Result<Vec<_>, _>>()?
548                    }
549                    Message::User { .. } => {
550                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
551                        return Err(CompletionError::ResponseError(
552                            "Received user message in response where assistant message was expected".to_owned()
553                        ));
554                    }
555                };
556
557                (content, usage)
558            }
559            CompletionResponse::Simple(text) => (
560                vec![completion::AssistantContent::text(text)],
561                completion::Usage::new(),
562            ),
563        };
564
565        let choice = OneOrMany::many(content).map_err(|_| {
566            CompletionError::ResponseError(
567                "Response contained no message or tool call (empty)".to_owned(),
568            )
569        })?;
570
571        Ok(completion::CompletionResponse {
572            choice,
573            usage,
574            raw_response: response,
575            message_id: None,
576        })
577    }
578}
579
580#[derive(Clone, Debug, Deserialize, Serialize)]
581pub struct Usage {
582    pub prompt_tokens: usize,
583    pub total_tokens: usize,
584}
585
586impl std::fmt::Display for Usage {
587    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588        write!(
589            f,
590            "Prompt tokens: {} Total tokens: {}",
591            self.prompt_tokens, self.total_tokens
592        )
593    }
594}
595
596impl From<Message> for serde_json::Value {
597    fn from(msg: Message) -> Self {
598        match msg {
599            Message::User { content } => {
600                let text = content
601                    .iter()
602                    .map(|c| match c {
603                        UserContent::Text(text) => &text.text,
604                        _ => "",
605                    })
606                    .collect::<Vec<_>>()
607                    .join("\n");
608                serde_json::json!({
609                    "role": "user",
610                    "content": text
611                })
612            }
613            Message::Assistant { content, .. } => {
614                let text = content
615                    .iter()
616                    .map(|c| match c {
617                        AssistantContent::Text(text) => &text.text,
618                        _ => "",
619                    })
620                    .collect::<Vec<_>>()
621                    .join("\n");
622                serde_json::json!({
623                    "role": "assistant",
624                    "content": text
625                })
626            }
627        }
628    }
629}
630
631impl TryFrom<serde_json::Value> for Message {
632    type Error = CompletionError;
633
634    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
635        let role = value["role"].as_str().ok_or_else(|| {
636            CompletionError::ResponseError("Message missing role field".to_owned())
637        })?;
638
639        // Handle both string and array content formats
640        let content = match value.get("content") {
641            Some(content) => match content {
642                serde_json::Value::String(s) => s.clone(),
643                serde_json::Value::Array(arr) => arr
644                    .iter()
645                    .filter_map(|c| {
646                        c.get("text")
647                            .and_then(|t| t.as_str())
648                            .map(|text| text.to_string())
649                    })
650                    .collect::<Vec<_>>()
651                    .join("\n"),
652                _ => {
653                    return Err(CompletionError::ResponseError(
654                        "Message content must be string or array".to_owned(),
655                    ));
656                }
657            },
658            None => {
659                return Err(CompletionError::ResponseError(
660                    "Message missing content field".to_owned(),
661                ));
662            }
663        };
664
665        match role {
666            "user" => Ok(Message::User {
667                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
668            }),
669            "assistant" => Ok(Message::Assistant {
670                id: None,
671                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
672            }),
673            _ => Err(CompletionError::ResponseError(format!(
674                "Unsupported message role: {role}"
675            ))),
676        }
677    }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    const GREETING: &str = "Hello there, how may I assist you today?";
    const QUESTION: &str = "What can you help me with?";

    /// Returns the first content part of an assistant message; panics otherwise.
    fn first_assistant_part(msg: Message) -> AssistantContent {
        match msg {
            Message::Assistant { content, .. } => content.first(),
            _ => panic!("Expected assistant message"),
        }
    }

    #[test]
    fn test_deserialize_message() {
        // String content format.
        let assistant_message = Message::try_from(json!({
            "role": "assistant",
            "content": GREETING
        }))
        .unwrap();
        let user_message = Message::try_from(json!({
            "role": "user",
            "content": QUESTION
        }))
        .unwrap();

        // Array-of-parts content format.
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": GREETING
            }]
        }))
        .unwrap();

        let expected_greeting = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });

        assert_eq!(first_assistant_part(assistant_message), expected_greeting);

        match user_message {
            Message::User { content } => assert_eq!(
                content.first(),
                UserContent::Text(message::Text {
                    text: QUESTION.to_string()
                })
            ),
            _ => panic!("Expected user message"),
        }

        assert_eq!(
            first_assistant_part(assistant_message_array),
            expected_greeting
        );
    }

    #[test]
    fn test_message_conversion() {
        // rig Message -> Mira JSON -> rig Message must round-trip.
        let original = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let round_tripped: Message = serde_json::Value::from(original.clone())
            .try_into()
            .unwrap();

        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let converted: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    #[test]
    fn test_client_initialization() {
        use crate::providers::mira::Client as MiraClient;

        let _direct: MiraClient = MiraClient::new("dummy-key").expect("Client::new() failed");
        let _from_builder: MiraClient = MiraClient::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}