inference_gateway_sdk/
lib.rs

1//! Inference Gateway SDK for Rust
2//!
3//! This crate provides a Rust client for the Inference Gateway API, allowing interaction
4//! with various LLM providers through a unified interface.
5
6use core::fmt;
7use std::future::Future;
8
9use futures_util::{Stream, StreamExt};
10use reqwest::{Client, StatusCode};
11use serde::{Deserialize, Serialize};
12
13use serde_json::Value;
14use thiserror::Error;
15
16/// Stream of Server-Sent Events (SSE) from the Inference Gateway API
17#[derive(Debug, Serialize, Deserialize)]
18pub struct SSEvents {
19    pub data: String,
20    pub event: Option<String>,
21    pub retry: Option<u64>,
22}
23
24/// Custom error types for the Inference Gateway SDK
25#[derive(Error, Debug)]
26pub enum GatewayError {
27    #[error("Unauthorized: {0}")]
28    Unauthorized(String),
29
30    #[error("Forbidden: {0}")]
31    Forbidden(String),
32
33    #[error("Not found: {0}")]
34    NotFound(String),
35
36    #[error("Bad request: {0}")]
37    BadRequest(String),
38
39    #[error("Internal server error: {0}")]
40    InternalError(String),
41
42    #[error("Stream error: {0}")]
43    StreamError(reqwest::Error),
44
45    #[error("Decoding error: {0}")]
46    DecodingError(std::string::FromUtf8Error),
47
48    #[error("Request error: {0}")]
49    RequestError(#[from] reqwest::Error),
50
51    #[error("Deserialization error: {0}")]
52    DeserializationError(serde_json::Error),
53
54    #[error("Serialization error: {0}")]
55    SerializationError(#[from] serde_json::Error),
56
57    #[error("Other error: {0}")]
58    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
59}
60
61#[derive(Debug, Deserialize)]
62struct ErrorResponse {
63    error: String,
64}
65
66/// Common model information
67#[derive(Debug, Serialize, Deserialize, Clone)]
68pub struct Model {
69    /// The model identifier
70    pub id: String,
71    /// The object type, usually "model"
72    pub object: String,
73    /// The Unix timestamp (in seconds) of when the model was created
74    pub created: i64,
75    /// The organization that owns the model
76    pub owned_by: String,
77    /// The provider that serves the model
78    pub served_by: Provider,
79}
80
81/// Response structure for listing models
82#[derive(Debug, Serialize, Deserialize)]
83pub struct ListModelsResponse {
84    /// The provider identifier
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub provider: Option<Provider>,
87    /// Response object type, usually "list"
88    pub object: String,
89    /// List of available models
90    pub data: Vec<Model>,
91}
92
93/// An MCP tool definition
94#[derive(Debug, Serialize, Deserialize, Clone)]
95pub struct MCPTool {
96    /// The name of the tool
97    pub name: String,
98    /// A description of what the tool does
99    pub description: String,
100    /// The MCP server that provides this tool
101    pub server: String,
102    /// JSON schema for the tool's input parameters (optional)
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub input_schema: Option<Value>,
105}
106
107/// Response structure for listing MCP tools
108#[derive(Debug, Serialize, Deserialize)]
109pub struct ListToolsResponse {
110    /// Response object type, always "list"
111    pub object: String,
112    /// Array of available MCP tools
113    pub data: Vec<MCPTool>,
114}
115
116/// Supported LLM providers
117#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
118#[serde(rename_all = "lowercase")]
119pub enum Provider {
120    #[serde(alias = "Ollama", alias = "OLLAMA")]
121    Ollama,
122    #[serde(alias = "OllamaCloud", alias = "OLLAMA_CLOUD", rename = "ollama_cloud")]
123    OllamaCloud,
124    #[serde(alias = "Groq", alias = "GROQ")]
125    Groq,
126    #[serde(alias = "OpenAI", alias = "OPENAI")]
127    OpenAI,
128    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
129    Cloudflare,
130    #[serde(alias = "Cohere", alias = "COHERE")]
131    Cohere,
132    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
133    Anthropic,
134    #[serde(alias = "Deepseek", alias = "DEEPSEEK")]
135    Deepseek,
136    #[serde(alias = "Google", alias = "GOOGLE")]
137    Google,
138    #[serde(alias = "Mistral", alias = "MISTRAL")]
139    Mistral,
140}
141
142impl fmt::Display for Provider {
143    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
144        match self {
145            Provider::Ollama => write!(f, "ollama"),
146            Provider::OllamaCloud => write!(f, "ollama_cloud"),
147            Provider::Groq => write!(f, "groq"),
148            Provider::OpenAI => write!(f, "openai"),
149            Provider::Cloudflare => write!(f, "cloudflare"),
150            Provider::Cohere => write!(f, "cohere"),
151            Provider::Anthropic => write!(f, "anthropic"),
152            Provider::Deepseek => write!(f, "deepseek"),
153            Provider::Google => write!(f, "google"),
154            Provider::Mistral => write!(f, "mistral"),
155        }
156    }
157}
158
159impl TryFrom<&str> for Provider {
160    type Error = GatewayError;
161
162    fn try_from(s: &str) -> Result<Self, Self::Error> {
163        match s.to_lowercase().as_str() {
164            "ollama" => Ok(Self::Ollama),
165            "ollama_cloud" => Ok(Self::OllamaCloud),
166            "groq" => Ok(Self::Groq),
167            "openai" => Ok(Self::OpenAI),
168            "cloudflare" => Ok(Self::Cloudflare),
169            "cohere" => Ok(Self::Cohere),
170            "anthropic" => Ok(Self::Anthropic),
171            "deepseek" => Ok(Self::Deepseek),
172            "google" => Ok(Self::Google),
173            "mistral" => Ok(Self::Mistral),
174            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {s}"))),
175        }
176    }
177}
178
179#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
180#[serde(rename_all = "lowercase")]
181pub enum MessageRole {
182    System,
183    #[default]
184    User,
185    Assistant,
186    Tool,
187}
188
189impl fmt::Display for MessageRole {
190    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
191        match self {
192            MessageRole::System => write!(f, "system"),
193            MessageRole::User => write!(f, "user"),
194            MessageRole::Assistant => write!(f, "assistant"),
195            MessageRole::Tool => write!(f, "tool"),
196        }
197    }
198}
199
200/// A message in a conversation with an LLM
201#[derive(Debug, Serialize, Deserialize, Clone, Default)]
202pub struct Message {
203    /// Role of the message sender ("system", "user", "assistant" or "tool")
204    pub role: MessageRole,
205    /// Content of the message
206    pub content: String,
207    /// The tools an LLM wants to invoke
208    #[serde(skip_serializing_if = "Option::is_none")]
209    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
210    /// Unique identifier of the tool call
211    #[serde(skip_serializing_if = "Option::is_none")]
212    pub tool_call_id: Option<String>,
213    /// The reasoning content of the message
214    #[serde(skip_serializing_if = "Option::is_none")]
215    pub reasoning_content: Option<String>,
216    /// The reasoning of the message (same as reasoning_content)
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub reasoning: Option<String>,
219}
220
221/// A tool call in a message response
222#[derive(Debug, Deserialize, Serialize, Clone)]
223pub struct ChatCompletionMessageToolCall {
224    /// Unique identifier of the tool call
225    pub id: String,
226    /// Type of the tool being called
227    #[serde(rename = "type")]
228    pub r#type: ChatCompletionToolType,
229    /// Function that was called
230    pub function: ChatCompletionMessageToolCallFunction,
231}
232
233/// Type of tool that can be called
234#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
235pub enum ChatCompletionToolType {
236    /// Function tool type
237    #[serde(rename = "function")]
238    Function,
239}
240
241/// Function details in a tool call
242#[derive(Debug, Deserialize, Serialize, Clone)]
243pub struct ChatCompletionMessageToolCallFunction {
244    /// Name of the function to call
245    pub name: String,
246    /// Arguments to the function in JSON string format
247    pub arguments: String,
248}
249
250// Add this helper method to make argument access more convenient
251impl ChatCompletionMessageToolCallFunction {
252    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
253        serde_json::from_str(&self.arguments)
254    }
255}
256
257/// Tool function to call
258#[derive(Debug, Serialize, Deserialize, Clone)]
259pub struct FunctionObject {
260    pub name: String,
261    pub description: String,
262    pub parameters: Value,
263}
264
265/// Type of tool that can be used by the model
266#[derive(Debug, Serialize, Deserialize, Clone)]
267#[serde(rename_all = "lowercase")]
268pub enum ToolType {
269    Function,
270}
271
272/// Tool to use for the LLM toolbox
273#[derive(Debug, Serialize, Deserialize, Clone)]
274pub struct Tool {
275    pub r#type: ToolType,
276    pub function: FunctionObject,
277}
278
279/// Request payload for generating content
280#[derive(Debug, Serialize)]
281struct CreateChatCompletionRequest {
282    /// Name of the model
283    model: String,
284    /// Conversation history and prompt
285    messages: Vec<Message>,
286    /// Enable streaming of responses
287    stream: bool,
288    /// Optional tools to use for generation
289    #[serde(skip_serializing_if = "Option::is_none")]
290    tools: Option<Vec<Tool>>,
291    /// Maximum number of tokens to generate
292    #[serde(skip_serializing_if = "Option::is_none")]
293    max_tokens: Option<i32>,
294    /// The format of the reasoning content. Can be `raw` or `parsed`.
295    #[serde(skip_serializing_if = "Option::is_none")]
296    reasoning_format: Option<String>,
297}
298
299/// A tool call chunk in streaming responses
300#[derive(Debug, Serialize, Deserialize, Clone)]
301pub struct ChatCompletionMessageToolCallChunk {
302    /// Index of the tool call in the array
303    pub index: i32,
304    /// Unique identifier of the tool call
305    #[serde(skip_serializing_if = "Option::is_none")]
306    pub id: Option<String>,
307    /// Type of tool that was called
308    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
309    pub r#type: Option<String>,
310    /// Function that the LLM wants to call
311    #[serde(skip_serializing_if = "Option::is_none")]
312    pub function: Option<ChatCompletionMessageToolCallFunction>,
313}
314
315/// The reason the model stopped generating tokens
316#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
317#[serde(rename_all = "snake_case")]
318pub enum FinishReason {
319    /// Model hit a natural stop point or a provided stop sequence
320    Stop,
321    /// Maximum number of tokens specified in the request was reached
322    Length,
323    /// Model called a tool
324    ToolCalls,
325    /// Content was omitted due to a flag from content filters
326    ContentFilter,
327    /// Function call (deprecated, use tool_calls)
328    FunctionCall,
329}
330
331#[derive(Debug, Deserialize, Clone)]
332pub struct ChatCompletionChoice {
333    pub finish_reason: FinishReason,
334    pub message: Message,
335    pub index: i32,
336    /// Log probability information for the choice
337    pub logprobs: Option<ChoiceLogprobs>,
338}
339
340/// The response from generating content
341#[derive(Debug, Deserialize, Clone)]
342pub struct CreateChatCompletionResponse {
343    pub id: String,
344    pub choices: Vec<ChatCompletionChoice>,
345    pub created: i64,
346    pub model: String,
347    pub object: String,
348}
349
350/// The response from streaming content generation
351#[derive(Debug, Deserialize, Clone)]
352pub struct CreateChatCompletionStreamResponse {
353    /// A unique identifier for the chat completion. Each chunk has the same ID.
354    pub id: String,
355    /// A list of chat completion choices. Can contain more than one element if `n` is greater than 1.
356    pub choices: Vec<ChatCompletionStreamChoice>,
357    /// The Unix timestamp (in seconds) of when the chat completion was created.
358    pub created: i64,
359    /// The model used to generate the completion.
360    pub model: String,
361    /// This fingerprint represents the backend configuration that the model runs with.
362    #[serde(skip_serializing_if = "Option::is_none")]
363    pub system_fingerprint: Option<String>,
364    /// The object type, which is always "chat.completion.chunk".
365    pub object: String,
366    /// Usage statistics for the completion request.
367    #[serde(skip_serializing_if = "Option::is_none")]
368    pub usage: Option<CompletionUsage>,
369    /// The format of the reasoning content. Can be `raw` or `parsed`.
370    #[serde(skip_serializing_if = "Option::is_none")]
371    pub reasoning_format: Option<String>,
372}
373
374/// Token log probability information
375#[derive(Debug, Deserialize, Clone)]
376pub struct ChatCompletionTokenLogprob {
377    /// The token
378    pub token: String,
379    /// The log probability of this token
380    pub logprob: f64,
381    /// UTF-8 bytes representation of the token
382    pub bytes: Option<Vec<i32>>,
383    /// List of the most likely tokens and their log probability
384    pub top_logprobs: Vec<TopLogprob>,
385}
386
387/// Top log probability entry
388#[derive(Debug, Deserialize, Clone)]
389pub struct TopLogprob {
390    /// The token
391    pub token: String,
392    /// The log probability of this token
393    pub logprob: f64,
394    /// UTF-8 bytes representation of the token
395    pub bytes: Option<Vec<i32>>,
396}
397
398/// Log probability information for a choice
399#[derive(Debug, Deserialize, Clone)]
400pub struct ChoiceLogprobs {
401    /// A list of message content tokens with log probability information
402    pub content: Option<Vec<ChatCompletionTokenLogprob>>,
403    /// A list of message refusal tokens with log probability information
404    pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
405}
406
407/// Choice in a streaming completion response
408#[derive(Debug, Deserialize, Clone)]
409pub struct ChatCompletionStreamChoice {
410    /// The delta content for this streaming chunk
411    pub delta: ChatCompletionStreamDelta,
412    /// Index of the choice in the choices array
413    pub index: i32,
414    /// The reason the model stopped generating tokens
415    #[serde(skip_serializing_if = "Option::is_none")]
416    pub finish_reason: Option<FinishReason>,
417    /// Log probability information for the choice
418    #[serde(skip_serializing_if = "Option::is_none")]
419    pub logprobs: Option<ChoiceLogprobs>,
420}
421
422/// Delta content for streaming responses
423#[derive(Debug, Deserialize, Clone)]
424pub struct ChatCompletionStreamDelta {
425    /// Role of the message sender
426    #[serde(skip_serializing_if = "Option::is_none")]
427    pub role: Option<MessageRole>,
428    /// Content of the message delta
429    #[serde(skip_serializing_if = "Option::is_none")]
430    pub content: Option<String>,
431    /// The reasoning content of the chunk message
432    #[serde(skip_serializing_if = "Option::is_none")]
433    pub reasoning_content: Option<String>,
434    /// The reasoning of the chunk message (same as reasoning_content)
435    #[serde(skip_serializing_if = "Option::is_none")]
436    pub reasoning: Option<String>,
437    /// Tool calls for this delta
438    #[serde(skip_serializing_if = "Option::is_none")]
439    pub tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>,
440    /// The refusal message generated by the model
441    #[serde(skip_serializing_if = "Option::is_none")]
442    pub refusal: Option<String>,
443}
444
445/// Usage statistics for the completion request
446#[derive(Debug, Deserialize, Clone)]
447pub struct CompletionUsage {
448    /// Number of tokens in the generated completion
449    pub completion_tokens: i64,
450    /// Number of tokens in the prompt
451    pub prompt_tokens: i64,
452    /// Total number of tokens used in the request (prompt + completion)
453    pub total_tokens: i64,
454}
455
456/// Client for interacting with the Inference Gateway API
457pub struct InferenceGatewayClient {
458    base_url: String,
459    client: Client,
460    token: Option<String>,
461    tools: Option<Vec<Tool>>,
462    max_tokens: Option<i32>,
463}
464
465/// Implement Debug for InferenceGatewayClient
466impl std::fmt::Debug for InferenceGatewayClient {
467    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468        f.debug_struct("InferenceGatewayClient")
469            .field("base_url", &self.base_url)
470            .field("token", &self.token.as_ref().map(|_| "*****"))
471            .finish()
472    }
473}
474
475/// Core API interface for the Inference Gateway
476pub trait InferenceGatewayAPI {
477    /// Lists available models from all providers
478    ///
479    /// # Errors
480    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
481    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
482    /// - Returns [`GatewayError::InternalError`] if the server has an error
483    /// - Returns [`GatewayError::Other`] for other errors
484    ///
485    /// # Returns
486    /// A list of models available from all providers
487    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
488
489    /// Lists available models by a specific provider
490    ///
491    /// # Arguments
492    /// * `provider` - The LLM provider to list models for
493    ///
494    /// # Errors
495    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
496    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
497    /// - Returns [`GatewayError::InternalError`] if the server has an error
498    /// - Returns [`GatewayError::Other`] for other errors
499    ///
500    /// # Returns
501    /// A list of models available from the specified provider
502    fn list_models_by_provider(
503        &self,
504        provider: Provider,
505    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
506
507    /// Generates content using a specified model
508    ///
509    /// # Arguments
510    /// * `provider` - The LLM provider to use
511    /// * `model` - Name of the model
512    /// * `messages` - Conversation history and prompt
513    /// * `tools` - Optional tools to use for generation
514    ///
515    /// # Errors
516    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
517    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
518    /// - Returns [`GatewayError::InternalError`] if the server has an error
519    /// - Returns [`GatewayError::Other`] for other errors
520    ///
521    /// # Returns
522    /// The generated response
523    fn generate_content(
524        &self,
525        provider: Provider,
526        model: &str,
527        messages: Vec<Message>,
528    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;
529
530    /// Stream content generation directly using the backend SSE stream.
531    ///
532    /// # Arguments
533    /// * `provider` - The LLM provider to use
534    /// * `model` - Name of the model
535    /// * `messages` - Conversation history and prompt
536    ///
537    /// # Returns
538    /// A stream of Server-Sent Events (SSE) from the Inference Gateway API
539    fn generate_content_stream(
540        &self,
541        provider: Provider,
542        model: &str,
543        messages: Vec<Message>,
544    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;
545
546    /// Lists available MCP tools
547    ///
548    /// # Errors
549    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
550    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
551    /// - Returns [`GatewayError::InternalError`] if the server has an error
552    /// - Returns [`GatewayError::Other`] for other errors
553    ///
554    /// # Returns
555    /// A list of available MCP tools. Only accessible when EXPOSE_MCP is enabled.
556    fn list_tools(&self) -> impl Future<Output = Result<ListToolsResponse, GatewayError>> + Send;
557
558    /// Checks if the API is available
559    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
560}
561
562impl InferenceGatewayClient {
563    /// Creates a new client instance
564    ///
565    /// # Arguments
566    /// * `base_url` - Base URL of the Inference Gateway API
567    pub fn new(base_url: &str) -> Self {
568        Self {
569            base_url: base_url.to_string(),
570            client: Client::new(),
571            token: None,
572            tools: None,
573            max_tokens: None,
574        }
575    }
576
577    /// Creates a new client instance with default configuration
578    /// pointing to http://localhost:8080/v1
579    pub fn new_default() -> Self {
580        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
581            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());
582
583        Self {
584            base_url,
585            client: Client::new(),
586            token: None,
587            tools: None,
588            max_tokens: None,
589        }
590    }
591
592    /// Returns the base URL this client is configured to use
593    pub fn base_url(&self) -> &str {
594        &self.base_url
595    }
596
597    /// Provides tools to use for generation
598    ///
599    /// # Arguments
600    /// * `tools` - List of tools to use for generation
601    ///
602    /// # Returns
603    /// Self with the tools set
604    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
605        self.tools = tools;
606        self
607    }
608
609    /// Sets an authentication token for the client
610    ///
611    /// # Arguments
612    /// * `token` - JWT token for authentication
613    ///
614    /// # Returns
615    /// Self with the token set
616    pub fn with_token(mut self, token: impl Into<String>) -> Self {
617        self.token = Some(token.into());
618        self
619    }
620
621    /// Sets the maximum number of tokens to generate
622    ///
623    /// # Arguments
624    /// * `max_tokens` - Maximum number of tokens to generate
625    ///
626    /// # Returns
627    /// Self with the maximum tokens set
628    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
629        self.max_tokens = max_tokens;
630        self
631    }
632}
633
634impl InferenceGatewayAPI for InferenceGatewayClient {
635    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
636        let url = format!("{}/models", self.base_url);
637        let mut request = self.client.get(&url);
638        if let Some(token) = &self.token {
639            request = request.bearer_auth(token);
640        }
641
642        let response = request.send().await?;
643        match response.status() {
644            StatusCode::OK => {
645                let json_response: ListModelsResponse = response.json().await?;
646                Ok(json_response)
647            }
648            StatusCode::UNAUTHORIZED => {
649                let error: ErrorResponse = response.json().await?;
650                Err(GatewayError::Unauthorized(error.error))
651            }
652            StatusCode::BAD_REQUEST => {
653                let error: ErrorResponse = response.json().await?;
654                Err(GatewayError::BadRequest(error.error))
655            }
656            StatusCode::INTERNAL_SERVER_ERROR => {
657                let error: ErrorResponse = response.json().await?;
658                Err(GatewayError::InternalError(error.error))
659            }
660            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
661                format!("Unexpected status code: {}", response.status()),
662            )))),
663        }
664    }
665
666    async fn list_models_by_provider(
667        &self,
668        provider: Provider,
669    ) -> Result<ListModelsResponse, GatewayError> {
670        let url = format!("{}/models?provider={}", self.base_url, provider);
671        let mut request = self.client.get(&url);
672        if let Some(token) = &self.token {
673            request = self.client.get(&url).bearer_auth(token);
674        }
675
676        let response = request.send().await?;
677        match response.status() {
678            StatusCode::OK => {
679                let json_response: ListModelsResponse = response.json().await?;
680                Ok(json_response)
681            }
682            StatusCode::UNAUTHORIZED => {
683                let error: ErrorResponse = response.json().await?;
684                Err(GatewayError::Unauthorized(error.error))
685            }
686            StatusCode::BAD_REQUEST => {
687                let error: ErrorResponse = response.json().await?;
688                Err(GatewayError::BadRequest(error.error))
689            }
690            StatusCode::INTERNAL_SERVER_ERROR => {
691                let error: ErrorResponse = response.json().await?;
692                Err(GatewayError::InternalError(error.error))
693            }
694            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
695                format!("Unexpected status code: {}", response.status()),
696            )))),
697        }
698    }
699
700    async fn generate_content(
701        &self,
702        provider: Provider,
703        model: &str,
704        messages: Vec<Message>,
705    ) -> Result<CreateChatCompletionResponse, GatewayError> {
706        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
707        let mut request = self.client.post(&url);
708        if let Some(token) = &self.token {
709            request = request.bearer_auth(token);
710        }
711
712        let request_payload = CreateChatCompletionRequest {
713            model: model.to_string(),
714            messages,
715            stream: false,
716            tools: self.tools.clone(),
717            max_tokens: self.max_tokens,
718            reasoning_format: None,
719        };
720
721        let response = request.json(&request_payload).send().await?;
722
723        match response.status() {
724            StatusCode::OK => Ok(response.json().await?),
725            StatusCode::BAD_REQUEST => {
726                let error: ErrorResponse = response.json().await?;
727                Err(GatewayError::BadRequest(error.error))
728            }
729            StatusCode::UNAUTHORIZED => {
730                let error: ErrorResponse = response.json().await?;
731                Err(GatewayError::Unauthorized(error.error))
732            }
733            StatusCode::INTERNAL_SERVER_ERROR => {
734                let error: ErrorResponse = response.json().await?;
735                Err(GatewayError::InternalError(error.error))
736            }
737            status => Err(GatewayError::Other(Box::new(std::io::Error::other(
738                format!("Unexpected status code: {status}"),
739            )))),
740        }
741    }
742
743    /// Stream content generation directly using the backend SSE stream.
744    fn generate_content_stream(
745        &self,
746        provider: Provider,
747        model: &str,
748        messages: Vec<Message>,
749    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
750        let client = self.client.clone();
751        let base_url = self.base_url.clone();
752        let url = format!(
753            "{}/chat/completions?provider={}",
754            base_url,
755            provider.to_string().to_lowercase()
756        );
757
758        let request = CreateChatCompletionRequest {
759            model: model.to_string(),
760            messages,
761            stream: true,
762            tools: None,
763            max_tokens: None,
764            reasoning_format: None,
765        };
766
767        async_stream::try_stream! {
768            let response = client.post(&url).json(&request).send().await?;
769            let mut stream = response.bytes_stream();
770            let mut current_event: Option<String> = None;
771            let mut current_data: Option<String> = None;
772
773            while let Some(chunk) = stream.next().await {
774                let chunk = chunk?;
775                let chunk_str = String::from_utf8_lossy(&chunk);
776
777                for line in chunk_str.lines() {
778                    if line.is_empty() && current_data.is_some() {
779                        yield SSEvents {
780                            data: current_data.take().unwrap(),
781                            event: current_event.take(),
782                            retry: None, // TODO - implement this, for now it's not implemented in the backend
783                        };
784                        continue;
785                    }
786
787                    if let Some(event) = line.strip_prefix("event:") {
788                        current_event = Some(event.trim().to_string());
789                    } else if let Some(data) = line.strip_prefix("data:") {
790                        let processed_data = data.strip_suffix('\n').unwrap_or(data);
791                        current_data = Some(processed_data.trim().to_string());
792                    }
793                }
794            }
795        }
796    }
797
798    async fn list_tools(&self) -> Result<ListToolsResponse, GatewayError> {
799        let url = format!("{}/mcp/tools", self.base_url);
800        let mut request = self.client.get(&url);
801        if let Some(token) = &self.token {
802            request = request.bearer_auth(token);
803        }
804
805        let response = request.send().await?;
806        match response.status() {
807            StatusCode::OK => {
808                let json_response: ListToolsResponse = response.json().await?;
809                Ok(json_response)
810            }
811            StatusCode::UNAUTHORIZED => {
812                let error: ErrorResponse = response.json().await?;
813                Err(GatewayError::Unauthorized(error.error))
814            }
815            StatusCode::BAD_REQUEST => {
816                let error: ErrorResponse = response.json().await?;
817                Err(GatewayError::BadRequest(error.error))
818            }
819            StatusCode::FORBIDDEN => {
820                let error: ErrorResponse = response.json().await?;
821                Err(GatewayError::Forbidden(error.error))
822            }
823            StatusCode::INTERNAL_SERVER_ERROR => {
824                let error: ErrorResponse = response.json().await?;
825                Err(GatewayError::InternalError(error.error))
826            }
827            _ => Err(GatewayError::Other(Box::new(std::io::Error::other(
828                format!("Unexpected status code: {}", response.status()),
829            )))),
830        }
831    }
832
833    async fn health_check(&self) -> Result<bool, GatewayError> {
834        let url = format!("{}/health", self.base_url);
835
836        let response = self.client.get(&url).send().await?;
837        match response.status() {
838            StatusCode::OK => Ok(true),
839            _ => Ok(false),
840        }
841    }
842}
843
844#[cfg(test)]
845mod tests {
846    use crate::{
847        CreateChatCompletionRequest, CreateChatCompletionResponse,
848        CreateChatCompletionStreamResponse, FinishReason, FunctionObject, GatewayError,
849        InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider, Tool,
850        ToolType,
851    };
852    use futures_util::{pin_mut, StreamExt};
853    use mockito::{Matcher, Server};
854    use serde_json::json;
855
856    #[test]
857    fn test_provider_serialization() {
858        let providers = vec![
859            (Provider::Ollama, "ollama"),
860            (Provider::OllamaCloud, "ollama_cloud"),
861            (Provider::Groq, "groq"),
862            (Provider::OpenAI, "openai"),
863            (Provider::Cloudflare, "cloudflare"),
864            (Provider::Cohere, "cohere"),
865            (Provider::Anthropic, "anthropic"),
866            (Provider::Deepseek, "deepseek"),
867            (Provider::Google, "google"),
868            (Provider::Mistral, "mistral"),
869        ];
870
871        for (provider, expected) in providers {
872            let json = serde_json::to_string(&provider).unwrap();
873            assert_eq!(json, format!("\"{}\"", expected));
874        }
875    }
876
877    #[test]
878    fn test_provider_deserialization() {
879        let test_cases = vec![
880            ("\"ollama\"", Provider::Ollama),
881            ("\"ollama_cloud\"", Provider::OllamaCloud),
882            ("\"groq\"", Provider::Groq),
883            ("\"openai\"", Provider::OpenAI),
884            ("\"cloudflare\"", Provider::Cloudflare),
885            ("\"cohere\"", Provider::Cohere),
886            ("\"anthropic\"", Provider::Anthropic),
887            ("\"deepseek\"", Provider::Deepseek),
888            ("\"google\"", Provider::Google),
889            ("\"mistral\"", Provider::Mistral),
890        ];
891
892        for (json, expected) in test_cases {
893            let provider: Provider = serde_json::from_str(json).unwrap();
894            assert_eq!(provider, expected);
895        }
896    }
897
898    #[test]
899    fn test_message_serialization_with_tool_call_id() {
900        let message_with_tool = Message {
901            role: MessageRole::Tool,
902            content: "The weather is sunny".to_string(),
903            tool_call_id: Some("call_123".to_string()),
904            ..Default::default()
905        };
906
907        let serialized = serde_json::to_string(&message_with_tool).unwrap();
908        let expected_with_tool =
909            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
910        assert_eq!(serialized, expected_with_tool);
911
912        let message_without_tool = Message {
913            role: MessageRole::User,
914            content: "What's the weather?".to_string(),
915            ..Default::default()
916        };
917
918        let serialized = serde_json::to_string(&message_without_tool).unwrap();
919        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
920        assert_eq!(serialized, expected_without_tool);
921
922        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
923        assert_eq!(deserialized.role, MessageRole::Tool);
924        assert_eq!(deserialized.content, "The weather is sunny");
925        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
926
927        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
928        assert_eq!(deserialized.role, MessageRole::User);
929        assert_eq!(deserialized.content, "What's the weather?");
930        assert_eq!(deserialized.tool_call_id, None);
931    }
932
933    #[test]
934    fn test_provider_display() {
935        let providers = vec![
936            (Provider::Ollama, "ollama"),
937            (Provider::OllamaCloud, "ollama_cloud"),
938            (Provider::Groq, "groq"),
939            (Provider::OpenAI, "openai"),
940            (Provider::Cloudflare, "cloudflare"),
941            (Provider::Cohere, "cohere"),
942            (Provider::Anthropic, "anthropic"),
943            (Provider::Deepseek, "deepseek"),
944            (Provider::Google, "google"),
945            (Provider::Mistral, "mistral"),
946        ];
947
948        for (provider, expected) in providers {
949            assert_eq!(provider.to_string(), expected);
950        }
951    }
952
953    #[test]
954    fn test_google_provider_case_insensitive() {
955        let test_cases = vec!["google", "Google", "GOOGLE", "GoOgLe"];
956
957        for test_case in test_cases {
958            let provider: Result<Provider, _> = test_case.try_into();
959            assert!(provider.is_ok(), "Failed to parse: {}", test_case);
960            assert_eq!(provider.unwrap(), Provider::Google);
961        }
962
963        let json_cases = vec![r#""google""#, r#""Google""#, r#""GOOGLE""#];
964
965        for json_case in json_cases {
966            let provider: Provider = serde_json::from_str(json_case).unwrap();
967            assert_eq!(provider, Provider::Google);
968        }
969
970        assert_eq!(Provider::Google.to_string(), "google");
971    }
972
973    #[test]
974    fn test_generate_request_serialization() {
975        let request_payload = CreateChatCompletionRequest {
976            model: "llama3.2:1b".to_string(),
977            messages: vec![
978                Message {
979                    role: MessageRole::System,
980                    content: "You are a helpful assistant.".to_string(),
981                    ..Default::default()
982                },
983                Message {
984                    role: MessageRole::User,
985                    content: "What is the current weather in Toronto?".to_string(),
986                    ..Default::default()
987                },
988            ],
989            stream: false,
990            tools: Some(vec![Tool {
991                r#type: ToolType::Function,
992                function: FunctionObject {
993                    name: "get_current_weather".to_string(),
994                    description: "Get the current weather of a city".to_string(),
995                    parameters: json!({
996                        "type": "object",
997                        "properties": {
998                            "city": {
999                                "type": "string",
1000                                "description": "The name of the city"
1001                            }
1002                        },
1003                        "required": ["city"]
1004                    }),
1005                },
1006            }]),
1007            max_tokens: None,
1008            reasoning_format: None,
1009        };
1010
1011        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
1012        let expected = r#"{
1013      "model": "llama3.2:1b",
1014      "messages": [
1015        {
1016          "role": "system",
1017          "content": "You are a helpful assistant."
1018        },
1019        {
1020          "role": "user",
1021          "content": "What is the current weather in Toronto?"
1022        }
1023      ],
1024      "stream": false,
1025      "tools": [
1026        {
1027          "type": "function",
1028          "function": {
1029            "name": "get_current_weather",
1030            "description": "Get the current weather of a city",
1031            "parameters": {
1032              "type": "object",
1033              "properties": {
1034                "city": {
1035                  "type": "string",
1036                  "description": "The name of the city"
1037                }
1038              },
1039              "required": ["city"]
1040            }
1041          }
1042        }
1043      ]
1044    }"#;
1045
1046        assert_eq!(
1047            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
1048            serde_json::from_str::<serde_json::Value>(expected).unwrap()
1049        );
1050    }
1051
1052    #[tokio::test]
1053    async fn test_authentication_header() -> Result<(), GatewayError> {
1054        let mut server = Server::new_async().await;
1055
1056        let mock_response = r#"{
1057            "object": "list",
1058            "data": []
1059        }"#;
1060
1061        let mock_with_auth = server
1062            .mock("GET", "/v1/models")
1063            .match_header("authorization", "Bearer test-token")
1064            .with_status(200)
1065            .with_header("content-type", "application/json")
1066            .with_body(mock_response)
1067            .expect(1)
1068            .create();
1069
1070        let base_url = format!("{}/v1", server.url());
1071        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
1072        client.list_models().await?;
1073        mock_with_auth.assert();
1074
1075        let mock_without_auth = server
1076            .mock("GET", "/v1/models")
1077            .match_header("authorization", Matcher::Missing)
1078            .with_status(200)
1079            .with_header("content-type", "application/json")
1080            .with_body(mock_response)
1081            .expect(1)
1082            .create();
1083
1084        let base_url = format!("{}/v1", server.url());
1085        let client = InferenceGatewayClient::new(&base_url);
1086        client.list_models().await?;
1087        mock_without_auth.assert();
1088
1089        Ok(())
1090    }
1091
1092    #[tokio::test]
1093    async fn test_unauthorized_error() -> Result<(), GatewayError> {
1094        let mut server = Server::new_async().await;
1095
1096        let raw_json_response = r#"{
1097            "error": "Invalid token"
1098        }"#;
1099
1100        let mock = server
1101            .mock("GET", "/v1/models")
1102            .with_status(401)
1103            .with_header("content-type", "application/json")
1104            .with_body(raw_json_response)
1105            .create();
1106
1107        let base_url = format!("{}/v1", server.url());
1108        let client = InferenceGatewayClient::new(&base_url);
1109        let error = client.list_models().await.unwrap_err();
1110
1111        assert!(matches!(error, GatewayError::Unauthorized(_)));
1112        if let GatewayError::Unauthorized(msg) = error {
1113            assert_eq!(msg, "Invalid token");
1114        }
1115        mock.assert();
1116
1117        Ok(())
1118    }
1119
1120    #[tokio::test]
1121    async fn test_list_models() -> Result<(), GatewayError> {
1122        let mut server = Server::new_async().await;
1123
1124        let raw_response_json = r#"{
1125            "object": "list",
1126            "data": [
1127                {
1128                    "id": "llama2",
1129                    "object": "model",
1130                    "created": 1630000001,
1131                    "owned_by": "ollama",
1132                    "served_by": "ollama"
1133                }
1134            ]
1135        }"#;
1136
1137        let mock = server
1138            .mock("GET", "/v1/models")
1139            .with_status(200)
1140            .with_header("content-type", "application/json")
1141            .with_body(raw_response_json)
1142            .create();
1143
1144        let base_url = format!("{}/v1", server.url());
1145        let client = InferenceGatewayClient::new(&base_url);
1146        let response = client.list_models().await?;
1147
1148        assert!(response.provider.is_none());
1149        assert_eq!(response.object, "list");
1150        assert_eq!(response.data.len(), 1);
1151        assert_eq!(response.data[0].id, "llama2");
1152        mock.assert();
1153
1154        Ok(())
1155    }
1156
1157    #[tokio::test]
1158    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
1159        let mut server = Server::new_async().await;
1160
1161        let raw_json_response = r#"{
1162            "provider":"ollama",
1163            "object":"list",
1164            "data": [
1165                {
1166                    "id": "llama2",
1167                    "object": "model",
1168                    "created": 1630000001,
1169                    "owned_by": "ollama",
1170                    "served_by": "ollama"
1171                }
1172            ]
1173        }"#;
1174
1175        let mock = server
1176            .mock("GET", "/v1/models?provider=ollama")
1177            .with_status(200)
1178            .with_header("content-type", "application/json")
1179            .with_body(raw_json_response)
1180            .create();
1181
1182        let base_url = format!("{}/v1", server.url());
1183        let client = InferenceGatewayClient::new(&base_url);
1184        let response = client.list_models_by_provider(Provider::Ollama).await?;
1185
1186        assert!(response.provider.is_some());
1187        assert_eq!(response.provider, Some(Provider::Ollama));
1188        assert_eq!(response.data[0].id, "llama2");
1189        mock.assert();
1190
1191        Ok(())
1192    }
1193
1194    #[tokio::test]
1195    async fn test_generate_content() -> Result<(), GatewayError> {
1196        let mut server = Server::new_async().await;
1197
1198        let raw_json_response = r#"{
1199            "id": "chatcmpl-456",
1200            "object": "chat.completion",
1201            "created": 1630000001,
1202            "model": "mixtral-8x7b",
1203            "choices": [
1204                {
1205                    "index": 0,
1206                    "finish_reason": "stop",
1207                    "logprobs": null,
1208                    "message": {
1209                        "role": "assistant",
1210                        "content": "Hellloooo"
1211                    }
1212                }
1213            ]
1214        }"#;
1215
1216        let mock = server
1217            .mock("POST", "/v1/chat/completions?provider=ollama")
1218            .with_status(200)
1219            .with_header("content-type", "application/json")
1220            .with_body(raw_json_response)
1221            .create();
1222
1223        let base_url = format!("{}/v1", server.url());
1224        let client = InferenceGatewayClient::new(&base_url);
1225
1226        let messages = vec![Message {
1227            role: MessageRole::User,
1228            content: "Hello".to_string(),
1229            ..Default::default()
1230        }];
1231        let response = client
1232            .generate_content(Provider::Ollama, "llama2", messages)
1233            .await?;
1234
1235        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1236        assert_eq!(response.choices[0].message.content, "Hellloooo");
1237        mock.assert();
1238
1239        Ok(())
1240    }
1241
1242    #[tokio::test]
1243    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
1244        let mut server = Server::new_async().await;
1245
1246        let raw_json = r#"{
1247            "id": "chatcmpl-456",
1248            "object": "chat.completion",
1249            "created": 1630000001,
1250            "model": "mixtral-8x7b",
1251            "choices": [
1252                {
1253                    "index": 0,
1254                    "finish_reason": "stop",
1255                    "logprobs": null,
1256                    "message": {
1257                        "role": "assistant",
1258                        "content": "Hello"
1259                    }
1260                }
1261            ]
1262        }"#;
1263
1264        let mock = server
1265            .mock("POST", "/v1/chat/completions?provider=groq")
1266            .with_status(200)
1267            .with_header("content-type", "application/json")
1268            .with_body(raw_json)
1269            .create();
1270
1271        let base_url = format!("{}/v1", server.url());
1272        let client = InferenceGatewayClient::new(&base_url);
1273
1274        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
1275        assert!(
1276            direct_parse.is_ok(),
1277            "Direct JSON parse failed: {:?}",
1278            direct_parse.err()
1279        );
1280
1281        let messages = vec![Message {
1282            role: MessageRole::User,
1283            content: "Hello".to_string(),
1284            ..Default::default()
1285        }];
1286
1287        let response = client
1288            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1289            .await?;
1290
1291        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1292        assert_eq!(response.choices[0].message.content, "Hello");
1293
1294        mock.assert();
1295        Ok(())
1296    }
1297
1298    #[tokio::test]
1299    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
1300        let mut server = Server::new_async().await;
1301
1302        let raw_json_response = r#"{
1303            "error":"Invalid request"
1304        }"#;
1305
1306        let mock = server
1307            .mock("POST", "/v1/chat/completions?provider=groq")
1308            .with_status(400)
1309            .with_header("content-type", "application/json")
1310            .with_body(raw_json_response)
1311            .create();
1312
1313        let base_url = format!("{}/v1", server.url());
1314        let client = InferenceGatewayClient::new(&base_url);
1315        let messages = vec![Message {
1316            role: MessageRole::User,
1317            content: "Hello".to_string(),
1318            ..Default::default()
1319        }];
1320        let error = client
1321            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1322            .await
1323            .unwrap_err();
1324
1325        assert!(matches!(error, GatewayError::BadRequest(_)));
1326        if let GatewayError::BadRequest(msg) = error {
1327            assert_eq!(msg, "Invalid request");
1328        }
1329        mock.assert();
1330
1331        Ok(())
1332    }
1333
1334    #[tokio::test]
1335    async fn test_gateway_errors() -> Result<(), GatewayError> {
1336        let mut server: mockito::ServerGuard = Server::new_async().await;
1337
1338        let unauthorized_mock = server
1339            .mock("GET", "/v1/models")
1340            .with_status(401)
1341            .with_header("content-type", "application/json")
1342            .with_body(r#"{"error":"Invalid token"}"#)
1343            .create();
1344
1345        let base_url = format!("{}/v1", server.url());
1346        let client = InferenceGatewayClient::new(&base_url);
1347        match client.list_models().await {
1348            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
1349            _ => panic!("Expected Unauthorized error"),
1350        }
1351        unauthorized_mock.assert();
1352
1353        let bad_request_mock = server
1354            .mock("GET", "/v1/models")
1355            .with_status(400)
1356            .with_header("content-type", "application/json")
1357            .with_body(r#"{"error":"Invalid provider"}"#)
1358            .create();
1359
1360        match client.list_models().await {
1361            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
1362            _ => panic!("Expected BadRequest error"),
1363        }
1364        bad_request_mock.assert();
1365
1366        let internal_error_mock = server
1367            .mock("GET", "/v1/models")
1368            .with_status(500)
1369            .with_header("content-type", "application/json")
1370            .with_body(r#"{"error":"Internal server error occurred"}"#)
1371            .create();
1372
1373        match client.list_models().await {
1374            Err(GatewayError::InternalError(msg)) => {
1375                assert_eq!(msg, "Internal server error occurred")
1376            }
1377            _ => panic!("Expected InternalError error"),
1378        }
1379        internal_error_mock.assert();
1380
1381        Ok(())
1382    }
1383
1384    #[tokio::test]
1385    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
1386        let mut server = Server::new_async().await;
1387
1388        let raw_json = r#"{
1389            "id": "chatcmpl-456",
1390            "object": "chat.completion",
1391            "created": 1630000001,
1392            "model": "mixtral-8x7b",
1393            "choices": [
1394                {
1395                    "index": 0,
1396                    "finish_reason": "stop",
1397                    "logprobs": null,
1398                    "message": {
1399                        "role": "assistant",
1400                        "content": "Hello"
1401                    }
1402                }
1403            ]
1404        }"#;
1405
1406        let mock = server
1407            .mock("POST", "/v1/chat/completions?provider=groq")
1408            .with_status(200)
1409            .with_header("content-type", "application/json")
1410            .with_body(raw_json)
1411            .create();
1412
1413        let base_url = format!("{}/v1", server.url());
1414        let client = InferenceGatewayClient::new(&base_url);
1415
1416        let messages = vec![Message {
1417            role: MessageRole::User,
1418            content: "Hello".to_string(),
1419            ..Default::default()
1420        }];
1421
1422        let response = client
1423            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1424            .await?;
1425
1426        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1427        assert_eq!(response.choices[0].message.content, "Hello");
1428        assert_eq!(response.model, "mixtral-8x7b");
1429        assert_eq!(response.object, "chat.completion");
1430        mock.assert();
1431
1432        Ok(())
1433    }
1434
1435    #[tokio::test]
1436    async fn test_generate_content_stream() -> Result<(), GatewayError> {
1437        let mut server = Server::new_async().await;
1438
1439        let mock = server
1440            .mock("POST", "/v1/chat/completions?provider=groq")
1441            .with_status(200)
1442            .with_header("content-type", "text/event-stream")
1443            .with_chunked_body(move |writer| -> std::io::Result<()> {
1444                let events = vec![
1445                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
1446                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
1447                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
1448                    format!("data: [DONE]\n\n")
1449                ];
1450                for event in events {
1451                    writer.write_all(event.as_bytes())?;
1452                }
1453                Ok(())
1454            })
1455            .create();
1456
1457        let base_url = format!("{}/v1", server.url());
1458        let client = InferenceGatewayClient::new(&base_url);
1459
1460        let messages = vec![Message {
1461            role: MessageRole::User,
1462            content: "Test message".to_string(),
1463            ..Default::default()
1464        }];
1465
1466        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
1467        pin_mut!(stream);
1468        while let Some(result) = stream.next().await {
1469            let result = result?;
1470            let generate_response: CreateChatCompletionStreamResponse =
1471                serde_json::from_str(&result.data)
1472                    .expect("Failed to parse CreateChatCompletionResponse");
1473
1474            if generate_response.choices[0].finish_reason.is_some() {
1475                assert_eq!(
1476                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
1477                    &FinishReason::Stop
1478                );
1479                break;
1480            }
1481
1482            if let Some(content) = &generate_response.choices[0].delta.content {
1483                assert!(matches!(content.as_str(), "Hello" | " World"));
1484            }
1485            if let Some(role) = &generate_response.choices[0].delta.role {
1486                assert_eq!(role, &MessageRole::Assistant);
1487            }
1488        }
1489
1490        mock.assert();
1491        Ok(())
1492    }
1493
1494    #[tokio::test]
1495    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
1496        let mut server = Server::new_async().await;
1497
1498        let mock = server
1499            .mock("POST", "/v1/chat/completions?provider=groq")
1500            .with_status(400)
1501            .with_header("content-type", "application/json")
1502            .with_chunked_body(move |writer| -> std::io::Result<()> {
1503                let events = vec![format!(
1504                    "event: {}\ndata: {}\nretry: {}\n\n",
1505                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
1506                )];
1507                for event in events {
1508                    writer.write_all(event.as_bytes())?;
1509                }
1510                Ok(())
1511            })
1512            .expect_at_least(1)
1513            .create();
1514
1515        let base_url = format!("{}/v1", server.url());
1516        let client = InferenceGatewayClient::new(&base_url);
1517
1518        let messages = vec![Message {
1519            role: MessageRole::User,
1520            content: "Test message".to_string(),
1521            ..Default::default()
1522        }];
1523
1524        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
1525
1526        pin_mut!(stream);
1527        while let Some(result) = stream.next().await {
1528            let result = result?;
1529            assert!(result.event.is_some());
1530            assert_eq!(result.event.unwrap(), "error");
1531            assert!(result.data.contains("Invalid request"));
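            // The mock emits `retry: 1000`, but no retry value is expected to be
            // surfaced for this event.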
1532            assert!(result.retry.is_none());
1533        }
1534
1535        mock.assert();
1536        Ok(())
1537    }
1538
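    // A tool_calls response should deserialize correctly, and the function
    // arguments should parse back into JSON via parse_arguments().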
1539    #[tokio::test]
1540    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
1541        let mut server = Server::new_async().await;
1542
1543        let raw_json_response = r#"{
1544            "id": "chatcmpl-123",
1545            "object": "chat.completion",
1546            "created": 1630000000,
1547            "model": "deepseek-r1-distill-llama-70b",
1548            "choices": [
1549                {
1550                    "index": 0,
1551                    "finish_reason": "tool_calls",
1552                    "logprobs": null,
1553                    "message": {
1554                        "role": "assistant",
1555                        "content": "Let me check the weather for you.",
1556                        "tool_calls": [
1557                            {
1558                                "id": "1234",
1559                                "type": "function",
1560                                "function": {
1561                                    "name": "get_weather",
1562                                    "arguments": "{\"location\": \"London\"}"
1563                                }
1564                            }
1565                        ]
1566                    }
1567                }
1568            ]
1569        }"#;
1570
1571        let mock = server
1572            .mock("POST", "/v1/chat/completions?provider=groq")
1573            .with_status(200)
1574            .with_header("content-type", "application/json")
1575            .with_body(raw_json_response)
1576            .create();
1577
1578        let tools = vec![Tool {
1579            r#type: ToolType::Function,
1580            function: FunctionObject {
1581                name: "get_weather".to_string(),
1582                description: "Get the weather for a location".to_string(),
1583                parameters: json!({
1584                    "type": "object",
1585                    "properties": {
1586                        "location": {
1587                            "type": "string",
1588                            "description": "The city name"
1589                        }
1590                    },
1591                    "required": ["location"]
1592                }),
1593            },
1594        }];
1595
1596        let base_url = format!("{}/v1", server.url());
1597        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));
1598
1599        let messages = vec![Message {
1600            role: MessageRole::User,
1601            content: "What's the weather in London?".to_string(),
1602            ..Default::default()
1603        }];
1604
1605        let response = client
1606            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
1607            .await?;
1608
1609        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1610        assert_eq!(
1611            response.choices[0].message.content,
1612            "Let me check the weather for you."
1613        );
1614
1615        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
1616        assert_eq!(tool_calls.len(), 1);
1617        assert_eq!(tool_calls[0].function.name, "get_weather");
1618
1619        let params = tool_calls[0]
1620            .function
1621            .parse_arguments()
1622            .expect("Failed to parse function arguments");
1623        assert_eq!(params["location"].as_str().unwrap(), "London");
1624
1625        mock.assert();
1626        Ok(())
1627    }
1628
1629    #[tokio::test]
1630    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
1631        let mut server = Server::new_async().await;
1632
1633        let raw_json_response = r#"{
1634            "id": "chatcmpl-123",
1635            "object": "chat.completion",
1636            "created": 1630000000,
1637            "model": "gpt-4",
1638            "choices": [
1639                {
1640                    "index": 0,
1641                    "finish_reason": "stop",
1642                    "logprobs": null,
1643                    "message": {
1644                        "role": "assistant",
1645                        "content": "Hello!"
1646                    }
1647                }
1648            ]
1649        }"#;
1650
1651        let mock = server
1652            .mock("POST", "/v1/chat/completions?provider=openai")
1653            .with_status(200)
1654            .with_header("content-type", "application/json")
1655            .with_body(raw_json_response)
1656            .create();
1657
1658        let base_url = format!("{}/v1", server.url());
1659        let client = InferenceGatewayClient::new(&base_url);
1660
1661        let messages = vec![Message {
1662            role: MessageRole::User,
1663            content: "Hi".to_string(),
1664            ..Default::default()
1665        }];
1666
1667        let response = client
1668            .generate_content(Provider::OpenAI, "gpt-4", messages)
1669            .await?;
1670
1671        assert_eq!(response.model, "gpt-4");
1672        assert_eq!(response.choices[0].message.content, "Hello!");
1673        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1674        assert!(response.choices[0].message.tool_calls.is_none());
1675
1676        mock.assert();
1677        Ok(())
1678    }
1679
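    // Asserts the exact JSON request body sent to the gateway when tools are
    // attached, using mockito's JSON body matcher.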
1680    #[tokio::test]
1681    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
1682        let mut server = Server::new_async().await;
1683
1684        let raw_request_body = r#"{
1685            "model": "deepseek-r1-distill-llama-70b",
1686            "messages": [
1687                {
1688                    "role": "system",
1689                    "content": "You are a helpful assistant."
1690                },
1691                {
1692                    "role": "user",
1693                    "content": "What is the current weather in Toronto?"
1694                }
1695            ],
1696            "stream": false,
1697            "tools": [
1698                {
1699                    "type": "function",
1700                    "function": {
1701                        "name": "get_current_weather",
1702                        "description": "Get the current weather of a city",
1703                        "parameters": {
1704                            "type": "object",
1705                            "properties": {
1706                                "city": {
1707                                    "type": "string",
1708                                    "description": "The name of the city"
1709                                }
1710                            },
1711                            "required": ["city"]
1712                        }
1713                    }
1714                }
1715            ]
1716        }"#;
1717
1718        let raw_json_response = r#"{
1719            "id": "1234",
1720            "object": "chat.completion",
1721            "created": 1630000000,
1722            "model": "deepseek-r1-distill-llama-70b",
1723            "choices": [
1724                {
1725                    "index": 0,
1726                    "finish_reason": "stop",
1727                    "logprobs": null,
1728                    "message": {
1729                        "role": "assistant",
1730                        "content": "Let me check the weather for you",
1731                        "tool_calls": [
1732                            {
1733                                "id": "1234",
1734                                "type": "function",
1735                                "function": {
1736                                    "name": "get_current_weather",
1737                                    "arguments": "{\"city\": \"Toronto\"}"
1738                                }
1739                            }
1740                        ]
1741                    }
1742                }
1743            ]
1744        }"#;
1745
1746        let mock = server
1747            .mock("POST", "/v1/chat/completions?provider=groq")
1748            .with_status(200)
1749            .with_header("content-type", "application/json")
1750            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
1751            .with_body(raw_json_response)
1752            .create();
1753
1754        let tools = vec![Tool {
1755            r#type: ToolType::Function,
1756            function: FunctionObject {
1757                name: "get_current_weather".to_string(),
1758                description: "Get the current weather of a city".to_string(),
1759                parameters: json!({
1760                    "type": "object",
1761                    "properties": {
1762                        "city": {
1763                            "type": "string",
1764                            "description": "The name of the city"
1765                        }
1766                    },
1767                    "required": ["city"]
1768                }),
1769            },
1770        }];
1771
1772        let base_url = format!("{}/v1", server.url());
1773        let client = InferenceGatewayClient::new(&base_url);
1774
1775        let messages = vec![
1776            Message {
1777                role: MessageRole::System,
1778                content: "You are a helpful assistant.".to_string(),
1779                ..Default::default()
1780            },
1781            Message {
1782                role: MessageRole::User,
1783                content: "What is the current weather in Toronto?".to_string(),
1784                ..Default::default()
1785            },
1786        ];
1787
1788        let response = client
1789            .with_tools(Some(tools))
1790            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
1791            .await?;
1792
1793        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1794        assert_eq!(
1795            response.choices[0].message.content,
1796            "Let me check the weather for you"
1797        );
1798        assert_eq!(
1799            response.choices[0]
1800                .message
1801                .tool_calls
1802                .as_ref()
1803                .unwrap()
1804                .len(),
1805            1
1806        );
1807
1808        mock.assert();
1809        Ok(())
1810    }
1811
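    // A client configured with with_max_tokens(Some(100)) should include
    // "max_tokens": 100 in the request body.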
1812    #[tokio::test]
1813    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
1814        let mut server = Server::new_async().await;
1815
1816        let raw_json_response = r#"{
1817            "id": "chatcmpl-123",
1818            "object": "chat.completion",
1819            "created": 1630000000,
1820            "model": "mixtral-8x7b",
1821            "choices": [
1822                {
1823                    "index": 0,
1824                    "finish_reason": "stop",
1825                    "logprobs": null,
1826                    "message": {
1827                        "role": "assistant",
1828                        "content": "Here's a poem with 100 tokens..."
1829                    }
1830                }
1831            ]
1832        }"#;
1833
1834        let mock = server
1835            .mock("POST", "/v1/chat/completions?provider=groq")
1836            .with_status(200)
1837            .with_header("content-type", "application/json")
1838            .match_body(mockito::Matcher::JsonString(
1839                r#"{
1840                "model": "mixtral-8x7b",
1841                "messages": [{"role":"user","content":"Write a poem"}],
1842                "stream": false,
1843                "max_tokens": 100
1844            }"#
1845                .to_string(),
1846            ))
1847            .with_body(raw_json_response)
1848            .create();
1849
1850        let base_url = format!("{}/v1", server.url());
1851        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));
1852
1853        let messages = vec![Message {
1854            role: MessageRole::User,
1855            content: "Write a poem".to_string(),
1856            ..Default::default()
1857        }];
1858
1859        let response = client
1860            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1861            .await?;
1862
1863        assert_eq!(
1864            response.choices[0].message.content,
1865            "Here's a poem with 100 tokens..."
1866        );
1867        assert_eq!(response.model, "mixtral-8x7b");
1868        assert_eq!(response.created, 1630000000);
1869        assert_eq!(response.object, "chat.completion");
1870
1871        mock.assert();
1872        Ok(())
1873    }
1874
1875    #[tokio::test]
1876    async fn test_health_check() -> Result<(), GatewayError> {
1877        let mut server = Server::new_async().await;
1878        let mock = server.mock("GET", "/health").with_status(200).create();
1879
1880        let client = InferenceGatewayClient::new(&server.url());
1881        let is_healthy = client.health_check().await?;
1882
1883        assert!(is_healthy);
1884        mock.assert();
1885
1886        Ok(())
1887    }
1888
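    // A custom base URL should be used for requests, and the default constructor
    // should fall back to http://localhost:8080/v1.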
1889    #[tokio::test]
1890    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
1891        let mut custom_url_server = Server::new_async().await;
1892
1893        let custom_url_mock = custom_url_server
1894            .mock("GET", "/health")
1895            .with_status(200)
1896            .create();
1897
1898        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
1899        let is_healthy = custom_client.health_check().await?;
1900        assert!(is_healthy);
1901        custom_url_mock.assert();
1902
1903        let default_client = InferenceGatewayClient::new_default();
1904
1905        let default_url = "http://localhost:8080/v1";
1906        assert_eq!(default_client.base_url(), default_url);
1907
1908        Ok(())
1909    }
1910
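    // The MCP tools listing should deserialize tools both with and without an
    // input_schema.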
1911    #[tokio::test]
1912    async fn test_list_tools() -> Result<(), GatewayError> {
1913        let mut server = Server::new_async().await;
1914
1915        let raw_response_json = r#"{
1916            "object": "list",
1917            "data": [
1918                {
1919                    "name": "read_file",
1920                    "description": "Read content from a file",
1921                    "server": "http://mcp-filesystem-server:8083/mcp",
1922                    "input_schema": {
1923                        "type": "object",
1924                        "properties": {
1925                            "file_path": {
1926                                "type": "string",
1927                                "description": "Path to the file to read"
1928                            }
1929                        },
1930                        "required": ["file_path"]
1931                    }
1932                },
1933                {
1934                    "name": "write_file",
1935                    "description": "Write content to a file",
1936                    "server": "http://mcp-filesystem-server:8083/mcp"
1937                }
1938            ]
1939        }"#;
1940
1941        let mock = server
1942            .mock("GET", "/v1/mcp/tools")
1943            .with_status(200)
1944            .with_header("content-type", "application/json")
1945            .with_body(raw_response_json)
1946            .create();
1947
1948        let base_url = format!("{}/v1", server.url());
1949        let client = InferenceGatewayClient::new(&base_url);
1950        let response = client.list_tools().await?;
1951
1952        assert_eq!(response.object, "list");
1953        assert_eq!(response.data.len(), 2);
1954
1955        // Test first tool with input_schema
1956        assert_eq!(response.data[0].name, "read_file");
1957        assert_eq!(response.data[0].description, "Read content from a file");
1958        assert_eq!(
1959            response.data[0].server,
1960            "http://mcp-filesystem-server:8083/mcp"
1961        );
1962        assert!(response.data[0].input_schema.is_some());
1963
1964        // Test second tool without input_schema
1965        assert_eq!(response.data[1].name, "write_file");
1966        assert_eq!(response.data[1].description, "Write content to a file");
1967        assert_eq!(
1968            response.data[1].server,
1969            "http://mcp-filesystem-server:8083/mcp"
1970        );
1971        assert!(response.data[1].input_schema.is_none());
1972
1973        mock.assert();
1974        Ok(())
1975    }
1976
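    // A client configured with a token should send it as a Bearer Authorization
    // header when listing MCP tools.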
1977    #[tokio::test]
1978    async fn test_list_tools_with_authentication() -> Result<(), GatewayError> {
1979        let mut server = Server::new_async().await;
1980
1981        let raw_response_json = r#"{
1982            "object": "list",
1983            "data": []
1984        }"#;
1985
1986        let mock = server
1987            .mock("GET", "/v1/mcp/tools")
1988            .match_header("authorization", "Bearer test-token")
1989            .with_status(200)
1990            .with_header("content-type", "application/json")
1991            .with_body(raw_response_json)
1992            .create();
1993
1994        let base_url = format!("{}/v1", server.url());
1995        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
1996        let response = client.list_tools().await?;
1997
1998        assert_eq!(response.object, "list");
1999        assert_eq!(response.data.len(), 0);
2000        mock.assert();
2001        Ok(())
2002    }
2003
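    // A 403 from the MCP tools endpoint should map to GatewayError::Forbidden
    // carrying the server-provided error message.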
2004    #[tokio::test]
2005    async fn test_list_tools_mcp_not_exposed() -> Result<(), GatewayError> {
2006        let mut server = Server::new_async().await;
2007
2008        let mock = server
2009            .mock("GET", "/v1/mcp/tools")
2010            .with_status(403)
2011            .with_header("content-type", "application/json")
2012            .with_body(
2013                r#"{"error":"MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."}"#,
2014            )
2015            .create();
2016
2017        let base_url = format!("{}/v1", server.url());
2018        let client = InferenceGatewayClient::new(&base_url);
2019
2020        match client.list_tools().await {
2021            Err(GatewayError::Forbidden(msg)) => {
2022                assert_eq!(
2023                    msg,
2024                    "MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable."
2025                );
2026            }
2027            _ => panic!("Expected Forbidden error for MCP not exposed"),
2028        }
2029
2030        mock.assert();
2031        Ok(())
2032    }
2033}