1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
29#[derive(Debug, Default, Clone, Copy)]
30pub struct MiraExt;
31#[derive(Debug, Default, Clone, Copy)]
32pub struct MiraBuilder;
33
34type MiraApiKey = BearerAuth;
35
36impl Provider for MiraExt {
37 type Builder = MiraBuilder;
38
39 const VERIFY_PATH: &'static str = "/user-credits";
40}
41
42impl<H> Capabilities<H> for MiraExt {
43 type Completion = Capable<CompletionModel<H>>;
44 type Embeddings = Nothing;
45 type Transcription = Nothing;
46 type ModelListing = Nothing;
47
48 #[cfg(feature = "image")]
49 type ImageGeneration = Nothing;
50
51 #[cfg(feature = "audio")]
52 type AudioGeneration = Nothing;
53 type Rerank = Nothing;
54}
55
56impl DebugExt for MiraExt {}
57
58impl ProviderBuilder for MiraBuilder {
59 type Extension<H>
60 = MiraExt
61 where
62 H: HttpClientExt;
63 type ApiKey = MiraApiKey;
64
65 const BASE_URL: &'static str = MIRA_API_BASE_URL;
66
67 fn build<H>(
68 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
69 ) -> http_client::Result<Self::Extension<H>>
70 where
71 H: HttpClientExt,
72 {
73 Ok(MiraExt)
74 }
75}
76
77pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
78pub type ClientBuilder<H = crate::markers::Missing> =
79 client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
80
81#[derive(Debug, Error)]
82pub enum MiraError {
83 #[error("Invalid API key")]
84 InvalidApiKey,
85 #[error("API error: {0}")]
86 ApiError(u16),
87 #[error("Request error: {0}")]
88 RequestError(#[from] http_client::Error),
89 #[error("UTF-8 error: {0}")]
90 Utf8Error(#[from] FromUtf8Error),
91 #[error("JSON error: {0}")]
92 JsonError(#[from] serde_json::Error),
93}
94
95#[derive(Debug, Deserialize)]
96struct ApiErrorResponse {
97 message: String,
98}
99
100#[derive(Debug, Deserialize, Clone, Serialize)]
101pub struct RawMessage {
102 pub role: String,
103 pub content: String,
104}
105
/// Default base URL for the Mira API; request paths such as `/v1/chat/completions` are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
107
108impl TryFrom<RawMessage> for message::Message {
109 type Error = CompletionError;
110
111 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
112 match raw.role.as_str() {
113 "system" => Ok(message::Message::System {
114 content: raw.content,
115 }),
116 "user" => Ok(message::Message::User {
117 content: OneOrMany::one(UserContent::Text(message::Text::new(raw.content))),
118 }),
119 "assistant" => Ok(message::Message::Assistant {
120 id: None,
121 content: OneOrMany::one(AssistantContent::Text(message::Text::new(raw.content))),
122 }),
123 _ => Err(CompletionError::ResponseError(format!(
124 "Unsupported message role: {}",
125 raw.role
126 ))),
127 }
128 }
129}
130
131#[derive(Debug, Deserialize, Serialize)]
132#[serde(untagged)]
133pub enum CompletionResponse {
134 Structured {
135 id: String,
136 object: String,
137 created: u64,
138 model: String,
139 choices: Vec<ChatChoice>,
140 #[serde(skip_serializing_if = "Option::is_none")]
141 usage: Option<Usage>,
142 },
143 Simple(String),
144}
145
146#[derive(Debug, Deserialize, Serialize)]
147pub struct ChatChoice {
148 pub message: RawMessage,
149 #[serde(default)]
150 pub finish_reason: Option<String>,
151 #[serde(default)]
152 pub index: Option<usize>,
153}
154
155#[derive(Debug, Deserialize, Serialize)]
156struct ModelsResponse {
157 data: Vec<ModelInfo>,
158}
159
160#[derive(Debug, Deserialize, Serialize)]
161struct ModelInfo {
162 id: String,
163}
164
165impl<T> Client<T>
166where
167 T: HttpClientExt + 'static,
168{
169 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
171 let req = self.get("/v1/models").and_then(|req| {
172 req.body(http_client::NoBody)
173 .map_err(http_client::Error::Protocol)
174 })?;
175
176 let response = self.send(req).await?;
177
178 let status = response.status();
179
180 if !status.is_success() {
181 let error_text = http_client::text(response).await.unwrap_or_default();
183 tracing::error!("Error response: {}", error_text);
184 return Err(MiraError::ApiError(status.as_u16()));
185 }
186
187 let response_text = http_client::text(response).await?;
188
189 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
190 tracing::error!("Failed to parse response: {}", e);
191 MiraError::JsonError(e)
192 })?;
193
194 Ok(models.data.into_iter().map(|model| model.id).collect())
195 }
196}
197
198impl ProviderClient for Client {
199 type Input = String;
200 type Error = crate::client::ProviderClientError;
201
202 fn from_env() -> Result<Self, Self::Error> {
204 let api_key = crate::client::required_env_var("MIRA_API_KEY")?;
205 Self::new(&api_key).map_err(Into::into)
206 }
207
208 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
209 Self::new(&input).map_err(Into::into)
210 }
211}
212
213#[derive(Debug, Serialize, Deserialize)]
214pub(super) struct MiraCompletionRequest {
215 model: String,
216 pub messages: Vec<RawMessage>,
217 #[serde(skip_serializing_if = "Option::is_none")]
218 temperature: Option<f64>,
219 #[serde(skip_serializing_if = "Option::is_none")]
220 max_tokens: Option<u64>,
221 pub stream: bool,
222}
223
224impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
225 type Error = CompletionError;
226
227 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
228 let chat_history = req.chat_history_with_documents();
229 if req.output_schema.is_some() {
230 tracing::warn!("Structured outputs currently not supported for Mira");
231 }
232 let model = req.model.clone().unwrap_or_else(|| model.to_string());
233 let mut messages = Vec::new();
234
235 if let Some(content) = &req.preamble {
236 messages.push(RawMessage {
237 role: "user".to_string(),
238 content: content.to_string(),
239 });
240 }
241
242 for msg in chat_history {
243 let (role, content) = match msg {
244 Message::System { content } => ("system", content),
245 Message::User { content } => {
246 let text =
247 content
248 .iter()
249 .map(|c| match c {
250 UserContent::Text(text) => &text.text,
251 UserContent::Document(Document {
252 data:
253 DocumentSourceKind::Base64(data)
254 | DocumentSourceKind::String(data),
255 ..
256 }) => data,
257 _ => "",
258 })
259 .collect::<Vec<_>>()
260 .join("\n");
261 ("user", text)
262 }
263 Message::Assistant { content, .. } => {
264 let text = content
265 .iter()
266 .map(|c| match c {
267 AssistantContent::Text(text) => &text.text,
268 _ => "",
269 })
270 .collect::<Vec<_>>()
271 .join("\n");
272 ("assistant", text)
273 }
274 };
275 messages.push(RawMessage {
276 role: role.to_string(),
277 content,
278 });
279 }
280
281 Ok(Self {
282 model: model.to_string(),
283 messages,
284 temperature: req.temperature,
285 max_tokens: req.max_tokens,
286 stream: false,
287 })
288 }
289}
290
291#[derive(Clone)]
292pub struct CompletionModel<T = reqwest::Client> {
293 client: Client<T>,
294 pub model: String,
296}
297
298impl<T> CompletionModel<T> {
299 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
300 Self {
301 client,
302 model: model.into(),
303 }
304 }
305}
306
307impl<T> completion::CompletionModel for CompletionModel<T>
308where
309 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
310{
311 type Response = CompletionResponse;
312 type StreamingResponse = openai::StreamingCompletionResponse;
313
314 type Client = Client<T>;
315
316 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
317 Self::new(client.clone(), model)
318 }
319
320 async fn completion(
321 &self,
322 completion_request: CompletionRequest,
323 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
324 let span = if tracing::Span::current().is_disabled() {
325 info_span!(
326 target: "rig::completions",
327 "chat",
328 gen_ai.operation.name = "chat",
329 gen_ai.provider.name = "mira",
330 gen_ai.request.model = self.model,
331 gen_ai.system_instructions = tracing::field::Empty,
332 gen_ai.response.id = tracing::field::Empty,
333 gen_ai.response.model = tracing::field::Empty,
334 gen_ai.usage.output_tokens = tracing::field::Empty,
335 gen_ai.usage.input_tokens = tracing::field::Empty,
336 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
337 )
338 } else {
339 tracing::Span::current()
340 };
341
342 span.record("gen_ai.system_instructions", &completion_request.preamble);
343
344 if !completion_request.tools.is_empty() {
345 tracing::warn!(target: "rig::completions",
346 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
347 len = completion_request.tools.len()
348 );
349 }
350
351 if completion_request.tool_choice.is_some() {
352 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
353 }
354
355 if completion_request.additional_params.is_some() {
356 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
357 }
358
359 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
360
361 if tracing::enabled!(tracing::Level::TRACE) {
362 tracing::trace!(target: "rig::completions",
363 "Mira completion request: {}",
364 serde_json::to_string_pretty(&request)?
365 );
366 }
367
368 let body = serde_json::to_vec(&request)?;
369
370 let req = self
371 .client
372 .post("/v1/chat/completions")?
373 .body(body)
374 .map_err(http_client::Error::from)?;
375
376 let async_block = async move {
377 let response = self
378 .client
379 .send::<_, bytes::Bytes>(req)
380 .await
381 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
382
383 let status = response.status();
384 let response_body = response.into_body().into_future().await?.to_vec();
385
386 if !status.is_success() {
387 let status = status.as_u16();
388 let error_text = String::from_utf8_lossy(&response_body).to_string();
389 return Err(CompletionError::ProviderError(format!(
390 "API error: {status} - {error_text}"
391 )));
392 }
393
394 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
395
396 if tracing::enabled!(tracing::Level::TRACE) {
397 tracing::trace!(target: "rig::completions",
398 "Mira completion response: {}",
399 serde_json::to_string_pretty(&response)?
400 );
401 }
402
403 if let CompletionResponse::Structured {
404 id, model, usage, ..
405 } = &response
406 {
407 let span = tracing::Span::current();
408 span.record("gen_ai.response.model", model);
409 span.record("gen_ai.response.id", id);
410 if let Some(usage) = usage {
411 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
412 span.record(
413 "gen_ai.usage.output_tokens",
414 usage.total_tokens - usage.prompt_tokens,
415 );
416 }
417 }
418
419 response.try_into()
420 };
421
422 async_block.instrument(span).await
423 }
424
425 async fn stream(
426 &self,
427 completion_request: CompletionRequest,
428 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
429 let span = if tracing::Span::current().is_disabled() {
430 info_span!(
431 target: "rig::completions",
432 "chat_streaming",
433 gen_ai.operation.name = "chat_streaming",
434 gen_ai.provider.name = "mira",
435 gen_ai.request.model = self.model,
436 gen_ai.system_instructions = tracing::field::Empty,
437 gen_ai.response.id = tracing::field::Empty,
438 gen_ai.response.model = tracing::field::Empty,
439 gen_ai.usage.output_tokens = tracing::field::Empty,
440 gen_ai.usage.input_tokens = tracing::field::Empty,
441 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
442 )
443 } else {
444 tracing::Span::current()
445 };
446
447 span.record("gen_ai.system_instructions", &completion_request.preamble);
448
449 if !completion_request.tools.is_empty() {
450 tracing::warn!(target: "rig::completions",
451 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
452 len = completion_request.tools.len()
453 );
454 }
455
456 if completion_request.tool_choice.is_some() {
457 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
458 }
459
460 if completion_request.additional_params.is_some() {
461 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
462 }
463 let mut request =
464 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
465 request.stream = true;
466
467 if tracing::enabled!(tracing::Level::TRACE) {
468 tracing::trace!(target: "rig::completions",
469 "Mira completion request: {}",
470 serde_json::to_string_pretty(&request)?
471 );
472 }
473
474 let body = serde_json::to_vec(&request)?;
475
476 let req = self
477 .client
478 .post("/v1/chat/completions")?
479 .body(body)
480 .map_err(http_client::Error::from)?;
481
482 send_compatible_streaming_request(self.client.clone(), req)
483 .instrument(span)
484 .await
485 }
486}
487
488impl From<ApiErrorResponse> for CompletionError {
489 fn from(err: ApiErrorResponse) -> Self {
490 CompletionError::ProviderError(err.message)
491 }
492}
493
494impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
495 type Error = CompletionError;
496
497 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
498 let (content, usage) = match &response {
499 CompletionResponse::Structured { choices, usage, .. } => {
500 let choice = choices.first().ok_or_else(|| {
501 CompletionError::ResponseError("Response contained no choices".to_owned())
502 })?;
503
504 let usage = usage
505 .as_ref()
506 .map(|usage| completion::Usage {
507 input_tokens: usage.prompt_tokens as u64,
508 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
509 total_tokens: usage.total_tokens as u64,
510 cached_input_tokens: 0,
511 cache_creation_input_tokens: 0,
512 tool_use_prompt_tokens: 0,
513 reasoning_tokens: 0,
514 })
515 .unwrap_or_default();
516
517 let message = message::Message::try_from(choice.message.clone())?;
519
520 let content = match message {
521 Message::Assistant { content, .. } => {
522 if content.is_empty() {
523 return Err(CompletionError::ResponseError(
524 "Response contained empty content".to_owned(),
525 ));
526 }
527
528 for c in content.iter() {
530 if !matches!(c, AssistantContent::Text(_)) {
531 tracing::warn!(target: "rig",
532 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
533 );
534 }
535 }
536
537 content.iter().map(|c| {
538 match c {
539 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
540 other => Err(CompletionError::ResponseError(
541 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
542 ))
543 }
544 }).collect::<Result<Vec<_>, _>>()?
545 }
546 Message::User { .. } => {
547 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
548 return Err(CompletionError::ResponseError(
549 "Received user message in response where assistant message was expected".to_owned()
550 ));
551 }
552 Message::System { .. } => {
553 tracing::warn!(target: "rig", "Received system message in response where assistant message was expected");
554 return Err(CompletionError::ResponseError(
555 "Received system message in response where assistant message was expected".to_owned(),
556 ));
557 }
558 };
559
560 (content, usage)
561 }
562 CompletionResponse::Simple(text) => (
563 vec![completion::AssistantContent::text(text)],
564 completion::Usage::new(),
565 ),
566 };
567
568 let choice = OneOrMany::many(content).map_err(|_| {
569 CompletionError::ResponseError(
570 "Response contained no message or tool call (empty)".to_owned(),
571 )
572 })?;
573
574 Ok(completion::CompletionResponse {
575 choice,
576 usage,
577 raw_response: response,
578 message_id: None,
579 })
580 }
581}
582
583#[derive(Clone, Debug, Deserialize, Serialize)]
584pub struct Usage {
585 pub prompt_tokens: usize,
586 pub total_tokens: usize,
587}
588
589impl std::fmt::Display for Usage {
590 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
591 write!(
592 f,
593 "Prompt tokens: {} Total tokens: {}",
594 self.prompt_tokens, self.total_tokens
595 )
596 }
597}
598
599impl From<Message> for serde_json::Value {
600 fn from(msg: Message) -> Self {
601 match msg {
602 Message::System { content } => serde_json::json!({
603 "role": "system",
604 "content": content
605 }),
606 Message::User { content } => {
607 let text = content
608 .iter()
609 .map(|c| match c {
610 UserContent::Text(text) => &text.text,
611 _ => "",
612 })
613 .collect::<Vec<_>>()
614 .join("\n");
615 serde_json::json!({
616 "role": "user",
617 "content": text
618 })
619 }
620 Message::Assistant { content, .. } => {
621 let text = content
622 .iter()
623 .map(|c| match c {
624 AssistantContent::Text(text) => &text.text,
625 _ => "",
626 })
627 .collect::<Vec<_>>()
628 .join("\n");
629 serde_json::json!({
630 "role": "assistant",
631 "content": text
632 })
633 }
634 }
635 }
636}
637
638impl TryFrom<serde_json::Value> for Message {
639 type Error = CompletionError;
640
641 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
642 let role = value
643 .get("role")
644 .and_then(serde_json::Value::as_str)
645 .ok_or_else(|| {
646 CompletionError::ResponseError("Message missing role field".to_owned())
647 })?;
648
649 let content = match value.get("content") {
651 Some(content) => match content {
652 serde_json::Value::String(s) => s.clone(),
653 serde_json::Value::Array(arr) => arr
654 .iter()
655 .filter_map(|c| {
656 c.get("text")
657 .and_then(|t| t.as_str())
658 .map(|text| text.to_string())
659 })
660 .collect::<Vec<_>>()
661 .join("\n"),
662 _ => {
663 return Err(CompletionError::ResponseError(
664 "Message content must be string or array".to_owned(),
665 ));
666 }
667 },
668 None => {
669 return Err(CompletionError::ResponseError(
670 "Message missing content field".to_owned(),
671 ));
672 }
673 };
674
675 match role {
676 "system" => Ok(Message::System { content }),
677 "user" => Ok(Message::User {
678 content: OneOrMany::one(UserContent::Text(message::Text::new(content))),
679 }),
680 "assistant" => Ok(Message::Assistant {
681 id: None,
682 content: OneOrMany::one(AssistantContent::Text(message::Text::new(content))),
683 }),
684 _ => Err(CompletionError::ResponseError(format!(
685 "Unsupported message role: {role}"
686 ))),
687 }
688 }
689}
690
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Assistant greeting shared by the deserialization fixtures.
    const GREETING: &str = "Hello there, how may I assist you today?";

    #[test]
    fn test_deserialize_message() {
        // Plain-string content for both roles.
        let assistant_message = Message::try_from(json!({
            "role": "assistant",
            "content": GREETING
        }))
        .unwrap();
        let user_message = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();
        // Array-of-parts content; only the `text` fields are kept.
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": GREETING
            }]
        }))
        .unwrap();

        let expected_greeting = AssistantContent::Text(message::Text::new(GREETING.to_string()));

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_greeting);

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text::new("What can you help me with?".to_string()))
        );

        let Message::Assistant { content, .. } = assistant_message_array else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_greeting);
    }

    #[test]
    fn test_message_conversion() {
        // A single-text user message must survive a Message -> JSON -> Message round trip.
        let original = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let as_json = serde_json::Value::from(original.clone());
        let round_tripped = Message::try_from(as_json).unwrap();

        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let converted =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response)
                .unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}