1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
/// Stateless provider marker for Mira; all behavior lives in its trait impls.
#[derive(Clone, Copy, Debug, Default)]
pub struct MiraExt;

/// Stateless builder marker used to configure Mira clients.
#[derive(Clone, Copy, Debug, Default)]
pub struct MiraBuilder;
33
/// API-key type for Mira clients; the key is handled as `BearerAuth`.
type MiraApiKey = BearerAuth;
35
36impl Provider for MiraExt {
37 type Builder = MiraBuilder;
38
39 const VERIFY_PATH: &'static str = "/user-credits";
40
41 fn build<H>(
42 _: &crate::client::ClientBuilder<
43 Self::Builder,
44 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
45 H,
46 >,
47 ) -> http_client::Result<Self> {
48 Ok(Self)
49 }
50}
51
52impl<H> Capabilities<H> for MiraExt {
53 type Completion = Capable<CompletionModel<H>>;
54 type Embeddings = Nothing;
55 type Transcription = Nothing;
56 type ModelListing = Nothing;
57
58 #[cfg(feature = "image")]
59 type ImageGeneration = Nothing;
60
61 #[cfg(feature = "audio")]
62 type AudioGeneration = Nothing;
63}
64
// Relies entirely on `DebugExt`'s default methods; nothing Mira-specific to add.
impl DebugExt for MiraExt {}
66
67impl ProviderBuilder for MiraBuilder {
68 type Output = MiraExt;
69 type ApiKey = MiraApiKey;
70
71 const BASE_URL: &'static str = MIRA_API_BASE_URL;
72}
73
/// Mira API client; `H` is the HTTP backend (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
/// Builder for [`Client`], parameterized with Mira's API-key type.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
76
/// Errors returned by Mira-specific client calls such as `Client::list_models`.
#[derive(Debug, Error)]
pub enum MiraError {
    /// NOTE(review): not constructed anywhere in this module — possibly kept for API compatibility.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// The API answered with a non-success HTTP status; the payload is the status code.
    #[error("API error: {0}")]
    ApiError(u16),
    /// Building or sending the HTTP request failed.
    #[error("Request error: {0}")]
    RequestError(#[from] http_client::Error),
    /// A byte buffer could not be decoded as UTF-8.
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// The response body was not the expected JSON.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
90
/// Error payload with a human-readable `message`; converted into
/// `CompletionError::ProviderError` by the `From` impl in this module.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
95
/// A chat message in Mira's wire format: a role string and plain-text content.
///
/// Used both for outgoing request messages and for the message inside each
/// response choice. Field order here is the serialized JSON key order.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct RawMessage {
    /// `"user"` or `"assistant"` in practice; other roles are rejected on conversion.
    pub role: String,
    pub content: String,
}
101
/// Default base URL for the Mira API.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
103
104impl TryFrom<RawMessage> for message::Message {
105 type Error = CompletionError;
106
107 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
108 match raw.role.as_str() {
109 "user" => Ok(message::Message::User {
110 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
111 }),
112 "assistant" => Ok(message::Message::Assistant {
113 id: None,
114 content: OneOrMany::one(AssistantContent::Text(message::Text {
115 text: raw.content,
116 })),
117 }),
118 _ => Err(CompletionError::ResponseError(format!(
119 "Unsupported message role: {}",
120 raw.role
121 ))),
122 }
123 }
124}
125
/// Raw body of a Mira chat-completion response.
///
/// `#[serde(untagged)]` means deserialization tries the variants in declaration
/// order: a full `Structured` object first, and otherwise a bare JSON string
/// (`Simple`). Keep `Structured` first.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    Structured {
        id: String,
        object: String,
        created: u64,
        model: String,
        choices: Vec<ChatChoice>,
        // Omitted from re-serialized output when the API did not report usage.
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    Simple(String),
}
140
/// One entry of `choices` in a structured completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    pub message: RawMessage,
    // `None` when the field is absent from the response.
    #[serde(default)]
    pub finish_reason: Option<String>,
    // `None` when the field is absent from the response.
    #[serde(default)]
    pub index: Option<usize>,
}
149
/// Body of `GET /v1/models` as consumed by `Client::list_models`.
#[derive(Debug, Deserialize, Serialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

/// A single model entry; only its `id` is read.
#[derive(Debug, Deserialize, Serialize)]
struct ModelInfo {
    id: String,
}
159
160impl<T> Client<T>
161where
162 T: HttpClientExt + 'static,
163{
164 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
166 let req = self.get("/v1/models").and_then(|req| {
167 req.body(http_client::NoBody)
168 .map_err(http_client::Error::Protocol)
169 })?;
170
171 let response = self.send(req).await?;
172
173 let status = response.status();
174
175 if !status.is_success() {
176 let error_text = http_client::text(response).await.unwrap_or_default();
178 tracing::error!("Error response: {}", error_text);
179 return Err(MiraError::ApiError(status.as_u16()));
180 }
181
182 let response_text = http_client::text(response).await?;
183
184 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
185 tracing::error!("Failed to parse response: {}", e);
186 MiraError::JsonError(e)
187 })?;
188
189 Ok(models.data.into_iter().map(|model| model.id).collect())
190 }
191}
192
193impl ProviderClient for Client {
194 type Input = String;
195
196 fn from_env() -> Self {
199 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
200 Self::new(&api_key).unwrap()
201 }
202
203 fn from_val(input: Self::Input) -> Self {
204 Self::new(&input).unwrap()
205 }
206}
207
/// JSON body sent to `POST /v1/chat/completions`.
///
/// Field order is the serialized key order; `temperature` and `max_tokens` are
/// left out of the JSON entirely when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MiraCompletionRequest {
    model: String,
    pub messages: Vec<RawMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Set to `true` by the streaming path after conversion.
    pub stream: bool,
}
218
219impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
220 type Error = CompletionError;
221
222 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
223 if req.output_schema.is_some() {
224 tracing::warn!("Structured outputs currently not supported for Mira");
225 }
226 let model = req.model.clone().unwrap_or_else(|| model.to_string());
227 let mut messages = Vec::new();
228
229 if let Some(content) = &req.preamble {
230 messages.push(RawMessage {
231 role: "user".to_string(),
232 content: content.to_string(),
233 });
234 }
235
236 if let Some(Message::User { content }) = req.normalized_documents() {
237 let text = content
238 .into_iter()
239 .filter_map(|doc| match doc {
240 UserContent::Document(Document {
241 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
242 ..
243 }) => Some(data),
244 UserContent::Text(text) => Some(text.text),
245
246 _ => None,
248 })
249 .collect::<Vec<_>>()
250 .join("\n");
251
252 messages.push(RawMessage {
253 role: "user".to_string(),
254 content: text,
255 });
256 }
257
258 for msg in req.chat_history {
259 let (role, content) = match msg {
260 Message::User { content } => {
261 let text = content
262 .iter()
263 .map(|c| match c {
264 UserContent::Text(text) => &text.text,
265 _ => "",
266 })
267 .collect::<Vec<_>>()
268 .join("\n");
269 ("user", text)
270 }
271 Message::Assistant { content, .. } => {
272 let text = content
273 .iter()
274 .map(|c| match c {
275 AssistantContent::Text(text) => &text.text,
276 _ => "",
277 })
278 .collect::<Vec<_>>()
279 .join("\n");
280 ("assistant", text)
281 }
282 };
283 messages.push(RawMessage {
284 role: role.to_string(),
285 content,
286 });
287 }
288
289 Ok(Self {
290 model: model.to_string(),
291 messages,
292 temperature: req.temperature,
293 max_tokens: req.max_tokens,
294 stream: false,
295 })
296 }
297}
298
/// Mira chat-completion model handle: a client plus the model name to request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Default model name; a request's own `model` field takes precedence.
    pub model: String,
}
305
306impl<T> CompletionModel<T> {
307 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
308 Self {
309 client,
310 model: model.into(),
311 }
312 }
313}
314
315impl<T> completion::CompletionModel for CompletionModel<T>
316where
317 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
318{
319 type Response = CompletionResponse;
320 type StreamingResponse = openai::StreamingCompletionResponse;
321
322 type Client = Client<T>;
323
324 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
325 Self::new(client.clone(), model)
326 }
327
328 async fn completion(
329 &self,
330 completion_request: CompletionRequest,
331 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
332 let span = if tracing::Span::current().is_disabled() {
333 info_span!(
334 target: "rig::completions",
335 "chat",
336 gen_ai.operation.name = "chat",
337 gen_ai.provider.name = "mira",
338 gen_ai.request.model = self.model,
339 gen_ai.system_instructions = tracing::field::Empty,
340 gen_ai.response.id = tracing::field::Empty,
341 gen_ai.response.model = tracing::field::Empty,
342 gen_ai.usage.output_tokens = tracing::field::Empty,
343 gen_ai.usage.input_tokens = tracing::field::Empty,
344 )
345 } else {
346 tracing::Span::current()
347 };
348
349 span.record("gen_ai.system_instructions", &completion_request.preamble);
350
351 if !completion_request.tools.is_empty() {
352 tracing::warn!(target: "rig::completions",
353 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
354 len = completion_request.tools.len()
355 );
356 }
357
358 if completion_request.tool_choice.is_some() {
359 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
360 }
361
362 if completion_request.additional_params.is_some() {
363 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
364 }
365
366 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
367
368 if tracing::enabled!(tracing::Level::TRACE) {
369 tracing::trace!(target: "rig::completions",
370 "Mira completion request: {}",
371 serde_json::to_string_pretty(&request)?
372 );
373 }
374
375 let body = serde_json::to_vec(&request)?;
376
377 let req = self
378 .client
379 .post("/v1/chat/completions")?
380 .body(body)
381 .map_err(http_client::Error::from)?;
382
383 let async_block = async move {
384 let response = self
385 .client
386 .send::<_, bytes::Bytes>(req)
387 .await
388 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
389
390 let status = response.status();
391 let response_body = response.into_body().into_future().await?.to_vec();
392
393 if !status.is_success() {
394 let status = status.as_u16();
395 let error_text = String::from_utf8_lossy(&response_body).to_string();
396 return Err(CompletionError::ProviderError(format!(
397 "API error: {status} - {error_text}"
398 )));
399 }
400
401 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
402
403 if tracing::enabled!(tracing::Level::TRACE) {
404 tracing::trace!(target: "rig::completions",
405 "Mira completion response: {}",
406 serde_json::to_string_pretty(&response)?
407 );
408 }
409
410 if let CompletionResponse::Structured {
411 id, model, usage, ..
412 } = &response
413 {
414 let span = tracing::Span::current();
415 span.record("gen_ai.response.model_name", model);
416 span.record("gen_ai.response.id", id);
417 if let Some(usage) = usage {
418 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
419 span.record(
420 "gen_ai.usage.output_tokens",
421 usage.total_tokens - usage.prompt_tokens,
422 );
423 }
424 }
425
426 response.try_into()
427 };
428
429 async_block.instrument(span).await
430 }
431
432 async fn stream(
433 &self,
434 completion_request: CompletionRequest,
435 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
436 let span = if tracing::Span::current().is_disabled() {
437 info_span!(
438 target: "rig::completions",
439 "chat_streaming",
440 gen_ai.operation.name = "chat_streaming",
441 gen_ai.provider.name = "mira",
442 gen_ai.request.model = self.model,
443 gen_ai.system_instructions = tracing::field::Empty,
444 gen_ai.response.id = tracing::field::Empty,
445 gen_ai.response.model = tracing::field::Empty,
446 gen_ai.usage.output_tokens = tracing::field::Empty,
447 gen_ai.usage.input_tokens = tracing::field::Empty,
448 )
449 } else {
450 tracing::Span::current()
451 };
452
453 span.record("gen_ai.system_instructions", &completion_request.preamble);
454
455 if !completion_request.tools.is_empty() {
456 tracing::warn!(target: "rig::completions",
457 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
458 len = completion_request.tools.len()
459 );
460 }
461
462 if completion_request.tool_choice.is_some() {
463 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
464 }
465
466 if completion_request.additional_params.is_some() {
467 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
468 }
469 let mut request =
470 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
471 request.stream = true;
472
473 if tracing::enabled!(tracing::Level::TRACE) {
474 tracing::trace!(target: "rig::completions",
475 "Mira completion request: {}",
476 serde_json::to_string_pretty(&request)?
477 );
478 }
479
480 let body = serde_json::to_vec(&request)?;
481
482 let req = self
483 .client
484 .post("/v1/chat/completions")?
485 .body(body)
486 .map_err(http_client::Error::from)?;
487
488 send_compatible_streaming_request(self.client.clone(), req)
489 .instrument(span)
490 .await
491 }
492}
493
494impl From<ApiErrorResponse> for CompletionError {
495 fn from(err: ApiErrorResponse) -> Self {
496 CompletionError::ProviderError(err.message)
497 }
498}
499
500impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
501 type Error = CompletionError;
502
503 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
504 let (content, usage) = match &response {
505 CompletionResponse::Structured { choices, usage, .. } => {
506 let choice = choices.first().ok_or_else(|| {
507 CompletionError::ResponseError("Response contained no choices".to_owned())
508 })?;
509
510 let usage = usage
511 .as_ref()
512 .map(|usage| completion::Usage {
513 input_tokens: usage.prompt_tokens as u64,
514 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
515 total_tokens: usage.total_tokens as u64,
516 cached_input_tokens: 0,
517 })
518 .unwrap_or_default();
519
520 let message = message::Message::try_from(choice.message.clone())?;
522
523 let content = match message {
524 Message::Assistant { content, .. } => {
525 if content.is_empty() {
526 return Err(CompletionError::ResponseError(
527 "Response contained empty content".to_owned(),
528 ));
529 }
530
531 for c in content.iter() {
533 if !matches!(c, AssistantContent::Text(_)) {
534 tracing::warn!(target: "rig",
535 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
536 );
537 }
538 }
539
540 content.iter().map(|c| {
541 match c {
542 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
543 other => Err(CompletionError::ResponseError(
544 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
545 ))
546 }
547 }).collect::<Result<Vec<_>, _>>()?
548 }
549 Message::User { .. } => {
550 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
551 return Err(CompletionError::ResponseError(
552 "Received user message in response where assistant message was expected".to_owned()
553 ));
554 }
555 };
556
557 (content, usage)
558 }
559 CompletionResponse::Simple(text) => (
560 vec![completion::AssistantContent::text(text)],
561 completion::Usage::new(),
562 ),
563 };
564
565 let choice = OneOrMany::many(content).map_err(|_| {
566 CompletionError::ResponseError(
567 "Response contained no message or tool call (empty)".to_owned(),
568 )
569 })?;
570
571 Ok(completion::CompletionResponse {
572 choice,
573 usage,
574 raw_response: response,
575 message_id: None,
576 })
577 }
578}
579
/// Token usage as reported in a Mira response: prompt and total counts only.
///
/// Code in this module derives the output-token count as `total - prompt`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
585
586impl std::fmt::Display for Usage {
587 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588 write!(
589 f,
590 "Prompt tokens: {} Total tokens: {}",
591 self.prompt_tokens, self.total_tokens
592 )
593 }
594}
595
596impl From<Message> for serde_json::Value {
597 fn from(msg: Message) -> Self {
598 match msg {
599 Message::User { content } => {
600 let text = content
601 .iter()
602 .map(|c| match c {
603 UserContent::Text(text) => &text.text,
604 _ => "",
605 })
606 .collect::<Vec<_>>()
607 .join("\n");
608 serde_json::json!({
609 "role": "user",
610 "content": text
611 })
612 }
613 Message::Assistant { content, .. } => {
614 let text = content
615 .iter()
616 .map(|c| match c {
617 AssistantContent::Text(text) => &text.text,
618 _ => "",
619 })
620 .collect::<Vec<_>>()
621 .join("\n");
622 serde_json::json!({
623 "role": "assistant",
624 "content": text
625 })
626 }
627 }
628 }
629}
630
631impl TryFrom<serde_json::Value> for Message {
632 type Error = CompletionError;
633
634 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
635 let role = value["role"].as_str().ok_or_else(|| {
636 CompletionError::ResponseError("Message missing role field".to_owned())
637 })?;
638
639 let content = match value.get("content") {
641 Some(content) => match content {
642 serde_json::Value::String(s) => s.clone(),
643 serde_json::Value::Array(arr) => arr
644 .iter()
645 .filter_map(|c| {
646 c.get("text")
647 .and_then(|t| t.as_str())
648 .map(|text| text.to_string())
649 })
650 .collect::<Vec<_>>()
651 .join("\n"),
652 _ => {
653 return Err(CompletionError::ResponseError(
654 "Message content must be string or array".to_owned(),
655 ));
656 }
657 },
658 None => {
659 return Err(CompletionError::ResponseError(
660 "Message missing content field".to_owned(),
661 ));
662 }
663 };
664
665 match role {
666 "user" => Ok(Message::User {
667 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
668 }),
669 "assistant" => Ok(Message::Assistant {
670 id: None,
671 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
672 }),
673 _ => Err(CompletionError::ResponseError(format!(
674 "Unsupported message role: {role}"
675 ))),
676 }
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// `TryFrom<serde_json::Value>` accepts both string content and an array of
    /// text parts, for both user and assistant roles.
    #[test]
    fn test_deserialize_message() {
        // Assistant message with plain string content.
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        // Same assistant text, expressed as an array of `{type, text}` parts.
        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        // String content becomes a single text part.
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        // Array content collapses to the same single text part.
        match assistant_message_array {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// A text-only user message survives Message -> JSON -> Message unchanged.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // Message -> serde_json::Value via this module's `From` impl.
        let mira_value: serde_json::Value = original_message.clone().into();

        // ...and back via `TryFrom<serde_json::Value>`.
        let converted_message: Message = mira_value.try_into().unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured response's first choice becomes the generic response's text.
    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
    /// Both construction paths (`Client::new` and the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client: crate::providers::mira::Client =
            crate::providers::mira::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: crate::providers::mira::Client =
            crate::providers::mira::Client::builder()
                .api_key("dummy-key")
                .build()
                .expect("Client::builder() failed");
    }
}