1use crate::client::{
11 ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
12};
13use crate::json_utils::merge;
14use crate::providers::openai;
15use crate::providers::openai::send_compatible_streaming_request;
16use crate::streaming::StreamingCompletionResponse;
17use crate::{
18 OneOrMany,
19 completion::{self, CompletionError, CompletionRequest},
20 impl_conversion_traits,
21 message::{self, AssistantContent, Message, UserContent},
22};
23use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use std::string::FromUtf8Error;
27use thiserror::Error;
28use tracing;
29
30#[derive(Debug, Error)]
31pub enum MiraError {
32 #[error("Invalid API key")]
33 InvalidApiKey,
34 #[error("API error: {0}")]
35 ApiError(u16),
36 #[error("Request error: {0}")]
37 RequestError(#[from] reqwest::Error),
38 #[error("UTF-8 error: {0}")]
39 Utf8Error(#[from] FromUtf8Error),
40 #[error("JSON error: {0}")]
41 JsonError(#[from] serde_json::Error),
42}
43
44#[derive(Debug, Deserialize)]
45struct ApiErrorResponse {
46 message: String,
47}
48
49#[derive(Debug, Deserialize, Clone, Serialize)]
50pub struct RawMessage {
51 pub role: String,
52 pub content: String,
53}
54
/// Base URL that [`ClientBuilder::new`] starts from; override it with [`ClientBuilder::base_url`].
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
56
57impl TryFrom<RawMessage> for message::Message {
58 type Error = CompletionError;
59
60 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
61 match raw.role.as_str() {
62 "user" => Ok(message::Message::User {
63 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
64 }),
65 "assistant" => Ok(message::Message::Assistant {
66 id: None,
67 content: OneOrMany::one(AssistantContent::Text(message::Text {
68 text: raw.content,
69 })),
70 }),
71 _ => Err(CompletionError::ResponseError(format!(
72 "Unsupported message role: {}",
73 raw.role
74 ))),
75 }
76 }
77}
78
79#[derive(Debug, Deserialize, Serialize)]
80#[serde(untagged)]
81pub enum CompletionResponse {
82 Structured {
83 id: String,
84 object: String,
85 created: u64,
86 model: String,
87 choices: Vec<ChatChoice>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 usage: Option<Usage>,
90 },
91 Simple(String),
92}
93
94#[derive(Debug, Deserialize, Serialize)]
95pub struct ChatChoice {
96 pub message: RawMessage,
97 #[serde(default)]
98 pub finish_reason: Option<String>,
99 #[serde(default)]
100 pub index: Option<usize>,
101}
102
103#[derive(Debug, Deserialize, Serialize)]
104struct ModelsResponse {
105 data: Vec<ModelInfo>,
106}
107
108#[derive(Debug, Deserialize, Serialize)]
109struct ModelInfo {
110 id: String,
111}
112
113pub struct ClientBuilder<'a> {
114 api_key: &'a str,
115 base_url: &'a str,
116 http_client: Option<reqwest::Client>,
117}
118
119impl<'a> ClientBuilder<'a> {
120 pub fn new(api_key: &'a str) -> Self {
121 Self {
122 api_key,
123 base_url: MIRA_API_BASE_URL,
124 http_client: None,
125 }
126 }
127
128 pub fn base_url(mut self, base_url: &'a str) -> Self {
129 self.base_url = base_url;
130 self
131 }
132
133 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
134 self.http_client = Some(client);
135 self
136 }
137
138 pub fn build(self) -> Result<Client, ClientBuilderError> {
139 let mut headers = HeaderMap::new();
140 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
141 headers.insert(
142 reqwest::header::ACCEPT,
143 HeaderValue::from_static("application/json"),
144 );
145 headers.insert(
146 reqwest::header::USER_AGENT,
147 HeaderValue::from_static("rig-client/1.0"),
148 );
149 let http_client = if let Some(http_client) = self.http_client {
150 http_client
151 } else {
152 reqwest::Client::builder().build()?
153 };
154
155 Ok(Client {
156 base_url: self.base_url.to_string(),
157 http_client,
158 api_key: self.api_key.to_string(),
159 headers,
160 })
161 }
162}
163
164#[derive(Clone)]
165pub struct Client {
167 base_url: String,
168 http_client: reqwest::Client,
169 api_key: String,
170 headers: HeaderMap,
171}
172
173impl std::fmt::Debug for Client {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("Client")
176 .field("base_url", &self.base_url)
177 .field("http_client", &self.http_client)
178 .field("api_key", &"<REDACTED>")
179 .field("headers", &self.headers)
180 .finish()
181 }
182}
183
184impl Client {
185 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
196 ClientBuilder::new(api_key)
197 }
198
199 pub fn new(api_key: &str) -> Self {
204 Self::builder(api_key)
205 .build()
206 .expect("Mira client should build")
207 }
208
209 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
211 let response = self.get("/v1/models").send().await?;
212
213 let status = response.status();
214
215 if !status.is_success() {
216 let _error_text = response.text().await.unwrap_or_default();
218 tracing::error!("Error response: {}", _error_text);
219 return Err(MiraError::ApiError(status.as_u16()));
220 }
221
222 let response_text = response.text().await?;
223
224 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
225 tracing::error!("Failed to parse response: {}", e);
226 MiraError::JsonError(e)
227 })?;
228
229 Ok(models.data.into_iter().map(|model| model.id).collect())
230 }
231
232 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
233 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
234 self.http_client
235 .post(url)
236 .bearer_auth(&self.api_key)
237 .headers(self.headers.clone())
238 }
239
240 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
241 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
242 self.http_client
243 .get(url)
244 .bearer_auth(&self.api_key)
245 .headers(self.headers.clone())
246 }
247}
248
249impl ProviderClient for Client {
250 fn from_env() -> Self {
253 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
254 Self::new(&api_key)
255 }
256
257 fn from_val(input: crate::client::ProviderValue) -> Self {
258 let crate::client::ProviderValue::Simple(api_key) = input else {
259 panic!("Incorrect provider value type")
260 };
261 Self::new(&api_key)
262 }
263}
264
265impl CompletionClient for Client {
266 type CompletionModel = CompletionModel;
267 fn completion_model(&self, model: &str) -> CompletionModel {
269 CompletionModel::new(self.to_owned(), model)
270 }
271}
272
273impl VerifyClient for Client {
274 #[cfg_attr(feature = "worker", worker::send)]
275 async fn verify(&self) -> Result<(), VerifyError> {
276 let response = self.get("/user-credits").send().await?;
277 match response.status() {
278 reqwest::StatusCode::OK => Ok(()),
279 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
280 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
281 Err(VerifyError::ProviderError(response.text().await?))
282 }
283 _ => {
284 response.error_for_status()?;
285 Ok(())
286 }
287 }
288 }
289}
290
291impl_conversion_traits!(
292 AsEmbeddings,
293 AsTranscription,
294 AsImageGeneration,
295 AsAudioGeneration for Client
296);
297
298#[derive(Clone)]
299pub struct CompletionModel {
300 client: Client,
301 pub model: String,
303}
304
305impl CompletionModel {
306 pub fn new(client: Client, model: &str) -> Self {
307 Self {
308 client,
309 model: model.to_string(),
310 }
311 }
312
313 fn create_completion_request(
314 &self,
315 completion_request: CompletionRequest,
316 ) -> Result<Value, CompletionError> {
317 let mut messages = Vec::new();
318
319 if let Some(preamble) = &completion_request.preamble {
321 messages.push(serde_json::json!({
322 "role": "user",
323 "content": preamble.to_string()
324 }));
325 }
326
327 if let Some(Message::User { content }) = completion_request.normalized_documents() {
329 let text = content
330 .into_iter()
331 .filter_map(|doc| match doc {
332 UserContent::Document(doc) => Some(doc.data),
333 UserContent::Text(text) => Some(text.text),
334
335 _ => None,
337 })
338 .collect::<Vec<_>>()
339 .join("\n");
340
341 messages.push(serde_json::json!({
342 "role": "user",
343 "content": text
344 }));
345 }
346
347 for msg in completion_request.chat_history {
349 let (role, content) = match msg {
350 Message::User { content } => {
351 let text = content
352 .iter()
353 .map(|c| match c {
354 UserContent::Text(text) => &text.text,
355 _ => "",
356 })
357 .collect::<Vec<_>>()
358 .join("\n");
359 ("user", text)
360 }
361 Message::Assistant { content, .. } => {
362 let text = content
363 .iter()
364 .map(|c| match c {
365 AssistantContent::Text(text) => &text.text,
366 _ => "",
367 })
368 .collect::<Vec<_>>()
369 .join("\n");
370 ("assistant", text)
371 }
372 };
373 messages.push(serde_json::json!({
374 "role": role,
375 "content": content
376 }));
377 }
378
379 let request = serde_json::json!({
380 "model": self.model,
381 "messages": messages,
382 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
383 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
384 "stream": false
385 });
386
387 Ok(request)
388 }
389}
390
391impl completion::CompletionModel for CompletionModel {
392 type Response = CompletionResponse;
393 type StreamingResponse = openai::StreamingCompletionResponse;
394
395 #[cfg_attr(feature = "worker", worker::send)]
396 async fn completion(
397 &self,
398 completion_request: CompletionRequest,
399 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
400 if !completion_request.tools.is_empty() {
401 tracing::warn!(target: "rig",
402 "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
403 completion_request.tools.len()
404 );
405 }
406
407 let mira_request = self.create_completion_request(completion_request)?;
408
409 let response = self
410 .client
411 .post("/v1/chat/completions")
412 .json(&mira_request)
413 .send()
414 .await
415 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
416
417 if !response.status().is_success() {
418 let status = response.status().as_u16();
419 let error_text = response.text().await.unwrap_or_default();
420 return Err(CompletionError::ProviderError(format!(
421 "API error: {status} - {error_text}"
422 )));
423 }
424
425 let response: CompletionResponse = response
426 .json()
427 .await
428 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
429
430 response.try_into()
431 }
432
433 #[cfg_attr(feature = "worker", worker::send)]
434 async fn stream(
435 &self,
436 completion_request: CompletionRequest,
437 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
438 let mut request = self.create_completion_request(completion_request)?;
439
440 request = merge(request, json!({"stream": true}));
441
442 let builder = self.client.post("/v1/chat/completions").json(&request);
443
444 send_compatible_streaming_request(builder).await
445 }
446}
447
448impl From<ApiErrorResponse> for CompletionError {
449 fn from(err: ApiErrorResponse) -> Self {
450 CompletionError::ProviderError(err.message)
451 }
452}
453
454impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
455 type Error = CompletionError;
456
457 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
458 let (content, usage) = match &response {
459 CompletionResponse::Structured { choices, usage, .. } => {
460 let choice = choices.first().ok_or_else(|| {
461 CompletionError::ResponseError("Response contained no choices".to_owned())
462 })?;
463
464 let usage = usage
465 .as_ref()
466 .map(|usage| completion::Usage {
467 input_tokens: usage.prompt_tokens as u64,
468 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
469 total_tokens: usage.total_tokens as u64,
470 })
471 .unwrap_or_default();
472
473 let message = message::Message::try_from(choice.message.clone())?;
475
476 let content = match message {
477 Message::Assistant { content, .. } => {
478 if content.is_empty() {
479 return Err(CompletionError::ResponseError(
480 "Response contained empty content".to_owned(),
481 ));
482 }
483
484 for c in content.iter() {
486 if !matches!(c, AssistantContent::Text(_)) {
487 tracing::warn!(target: "rig",
488 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
489 );
490 }
491 }
492
493 content.iter().map(|c| {
494 match c {
495 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
496 other => Err(CompletionError::ResponseError(
497 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
498 ))
499 }
500 }).collect::<Result<Vec<_>, _>>()?
501 }
502 Message::User { .. } => {
503 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
504 return Err(CompletionError::ResponseError(
505 "Received user message in response where assistant message was expected".to_owned()
506 ));
507 }
508 };
509
510 (content, usage)
511 }
512 CompletionResponse::Simple(text) => (
513 vec![completion::AssistantContent::text(text)],
514 completion::Usage::new(),
515 ),
516 };
517
518 let choice = OneOrMany::many(content).map_err(|_| {
519 CompletionError::ResponseError(
520 "Response contained no message or tool call (empty)".to_owned(),
521 )
522 })?;
523
524 Ok(completion::CompletionResponse {
525 choice,
526 usage,
527 raw_response: response,
528 })
529 }
530}
531
532#[derive(Clone, Debug, Deserialize, Serialize)]
533pub struct Usage {
534 pub prompt_tokens: usize,
535 pub total_tokens: usize,
536}
537
538impl std::fmt::Display for Usage {
539 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
540 write!(
541 f,
542 "Prompt tokens: {} Total tokens: {}",
543 self.prompt_tokens, self.total_tokens
544 )
545 }
546}
547
548impl From<Message> for serde_json::Value {
549 fn from(msg: Message) -> Self {
550 match msg {
551 Message::User { content } => {
552 let text = content
553 .iter()
554 .map(|c| match c {
555 UserContent::Text(text) => &text.text,
556 _ => "",
557 })
558 .collect::<Vec<_>>()
559 .join("\n");
560 serde_json::json!({
561 "role": "user",
562 "content": text
563 })
564 }
565 Message::Assistant { content, .. } => {
566 let text = content
567 .iter()
568 .map(|c| match c {
569 AssistantContent::Text(text) => &text.text,
570 _ => "",
571 })
572 .collect::<Vec<_>>()
573 .join("\n");
574 serde_json::json!({
575 "role": "assistant",
576 "content": text
577 })
578 }
579 }
580 }
581}
582
583impl TryFrom<serde_json::Value> for Message {
584 type Error = CompletionError;
585
586 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
587 let role = value["role"].as_str().ok_or_else(|| {
588 CompletionError::ResponseError("Message missing role field".to_owned())
589 })?;
590
591 let content = match value.get("content") {
593 Some(content) => match content {
594 serde_json::Value::String(s) => s.clone(),
595 serde_json::Value::Array(arr) => arr
596 .iter()
597 .filter_map(|c| {
598 c.get("text")
599 .and_then(|t| t.as_str())
600 .map(|text| text.to_string())
601 })
602 .collect::<Vec<_>>()
603 .join("\n"),
604 _ => {
605 return Err(CompletionError::ResponseError(
606 "Message content must be string or array".to_owned(),
607 ));
608 }
609 },
610 None => {
611 return Err(CompletionError::ResponseError(
612 "Message missing content field".to_owned(),
613 ));
614 }
615 };
616
617 match role {
618 "user" => Ok(Message::User {
619 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
620 }),
621 "assistant" => Ok(Message::Assistant {
622 id: None,
623 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
624 }),
625 _ => Err(CompletionError::ResponseError(format!(
626 "Unsupported message role: {role}"
627 ))),
628 }
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// String and array `content` both parse into a single text item.
    #[test]
    fn test_deserialize_message() {
        let greeting = "Hello there, how may I assist you today?";
        let question = "What can you help me with?";

        let assistant_message = Message::try_from(json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        }))
        .unwrap();

        let user_message = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();

        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        }))
        .unwrap();

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            AssistantContent::Text(message::Text {
                text: greeting.to_string()
            })
        );

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: question.to_string()
            })
        );

        let Message::Assistant { content, .. } = assistant_message_array else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            AssistantContent::Text(message::Text {
                text: greeting.to_string()
            })
        );
    }

    /// A user text message survives a round trip through `serde_json::Value`.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value = serde_json::Value::from(original_message.clone());
        let converted_message = Message::try_from(mira_value).unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured reply converts into a response whose first item is the reply text.
    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };

        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response).unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}