inference_gateway_sdk/
lib.rs

1//! Inference Gateway SDK for Rust
2//!
3//! This crate provides a Rust client for the Inference Gateway API, allowing interaction
4//! with various LLM providers through a unified interface.
5
6use core::fmt;
7use std::future::Future;
8
9use futures_util::{Stream, StreamExt};
10use reqwest::{Client, StatusCode};
11use serde::{Deserialize, Serialize};
12
13use serde_json::Value;
14use thiserror::Error;
15
16/// Stream of Server-Sent Events (SSE) from the Inference Gateway API
17#[derive(Debug, Serialize, Deserialize)]
18pub struct SSEvents {
19    pub data: String,
20    pub event: Option<String>,
21    pub retry: Option<u64>,
22}
23
24/// Custom error types for the Inference Gateway SDK
25#[derive(Error, Debug)]
26pub enum GatewayError {
27    #[error("Unauthorized: {0}")]
28    Unauthorized(String),
29
30    #[error("Bad request: {0}")]
31    BadRequest(String),
32
33    #[error("Internal server error: {0}")]
34    InternalError(String),
35
36    #[error("Stream error: {0}")]
37    StreamError(reqwest::Error),
38
39    #[error("Decoding error: {0}")]
40    DecodingError(std::string::FromUtf8Error),
41
42    #[error("Request error: {0}")]
43    RequestError(#[from] reqwest::Error),
44
45    #[error("Deserialization error: {0}")]
46    DeserializationError(serde_json::Error),
47
48    #[error("Serialization error: {0}")]
49    SerializationError(#[from] serde_json::Error),
50
51    #[error("Other error: {0}")]
52    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
53}
54
55#[derive(Debug, Deserialize)]
56struct ErrorResponse {
57    error: String,
58}
59
60/// Common model information
61#[derive(Debug, Serialize, Deserialize, Clone)]
62pub struct Model {
63    /// The model identifier
64    pub id: String,
65    /// The object type, usually "model"
66    pub object: Option<String>,
67    /// The Unix timestamp (in seconds) of when the model was created
68    pub created: Option<i64>,
69    /// The organization that owns the model
70    pub owned_by: Option<String>,
71    /// The provider that serves the model
72    pub served_by: Option<String>,
73}
74
75/// Response structure for listing models
76#[derive(Debug, Serialize, Deserialize)]
77pub struct ListModelsResponse {
78    /// The provider identifier
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub provider: Option<Provider>,
81    /// Response object type, usually "list"
82    pub object: String,
83    /// List of available models
84    pub data: Vec<Model>,
85}
86
87/// Supported LLM providers
88#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
89#[serde(rename_all = "lowercase")]
90pub enum Provider {
91    #[serde(alias = "Ollama", alias = "OLLAMA")]
92    Ollama,
93    #[serde(alias = "Groq", alias = "GROQ")]
94    Groq,
95    #[serde(alias = "OpenAI", alias = "OPENAI")]
96    OpenAI,
97    #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
98    Cloudflare,
99    #[serde(alias = "Cohere", alias = "COHERE")]
100    Cohere,
101    #[serde(alias = "Anthropic", alias = "ANTHROPIC")]
102    Anthropic,
103}
104
105impl fmt::Display for Provider {
106    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
107        match self {
108            Provider::Ollama => write!(f, "ollama"),
109            Provider::Groq => write!(f, "groq"),
110            Provider::OpenAI => write!(f, "openai"),
111            Provider::Cloudflare => write!(f, "cloudflare"),
112            Provider::Cohere => write!(f, "cohere"),
113            Provider::Anthropic => write!(f, "anthropic"),
114        }
115    }
116}
117
118impl TryFrom<&str> for Provider {
119    type Error = GatewayError;
120
121    fn try_from(s: &str) -> Result<Self, Self::Error> {
122        match s.to_lowercase().as_str() {
123            "ollama" => Ok(Self::Ollama),
124            "groq" => Ok(Self::Groq),
125            "openai" => Ok(Self::OpenAI),
126            "cloudflare" => Ok(Self::Cloudflare),
127            "cohere" => Ok(Self::Cohere),
128            "anthropic" => Ok(Self::Anthropic),
129            _ => Err(GatewayError::BadRequest(format!("Unknown provider: {}", s))),
130        }
131    }
132}
133
134#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
135#[serde(rename_all = "lowercase")]
136pub enum MessageRole {
137    System,
138    #[default]
139    User,
140    Assistant,
141    Tool,
142}
143
144impl fmt::Display for MessageRole {
145    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146        match self {
147            MessageRole::System => write!(f, "system"),
148            MessageRole::User => write!(f, "user"),
149            MessageRole::Assistant => write!(f, "assistant"),
150            MessageRole::Tool => write!(f, "tool"),
151        }
152    }
153}
154
155/// A message in a conversation with an LLM
156#[derive(Debug, Serialize, Deserialize, Clone, Default)]
157pub struct Message {
158    /// Role of the message sender ("system", "user", "assistant" or "tool")
159    pub role: MessageRole,
160    /// Content of the message
161    pub content: String,
162    /// The tools an LLM wants to invoke
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
165    /// Unique identifier of the tool call
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub tool_call_id: Option<String>,
168    /// Reasoning behind the message
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub reasoning: Option<String>,
171}
172
173/// A tool call in a message response
174#[derive(Debug, Deserialize, Serialize, Clone)]
175pub struct ChatCompletionMessageToolCall {
176    /// Unique identifier of the tool call
177    pub id: String,
178    /// Type of the tool being called
179    #[serde(rename = "type")]
180    pub r#type: ChatCompletionToolType,
181    /// Function that was called
182    pub function: ChatCompletionMessageToolCallFunction,
183}
184
185/// Type of tool that can be called
186#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
187pub enum ChatCompletionToolType {
188    /// Function tool type
189    #[serde(rename = "function")]
190    Function,
191}
192
193/// Function details in a tool call
194#[derive(Debug, Deserialize, Serialize, Clone)]
195pub struct ChatCompletionMessageToolCallFunction {
196    /// Name of the function to call
197    pub name: String,
198    /// Arguments to the function in JSON string format
199    pub arguments: String,
200}
201
202// Add this helper method to make argument access more convenient
203impl ChatCompletionMessageToolCallFunction {
204    pub fn parse_arguments(&self) -> Result<serde_json::Value, serde_json::Error> {
205        serde_json::from_str(&self.arguments)
206    }
207}
208
209/// Tool function to call
210#[derive(Debug, Serialize, Deserialize, Clone)]
211pub struct FunctionObject {
212    pub name: String,
213    pub description: String,
214    pub parameters: Value,
215}
216
217/// Type of tool that can be used by the model
218#[derive(Debug, Serialize, Deserialize, Clone)]
219#[serde(rename_all = "lowercase")]
220pub enum ToolType {
221    Function,
222}
223
224/// Tool to use for the LLM toolbox
225#[derive(Debug, Serialize, Deserialize, Clone)]
226pub struct Tool {
227    pub r#type: ToolType,
228    pub function: FunctionObject,
229}
230
231/// Request payload for generating content
232#[derive(Debug, Serialize)]
233struct CreateChatCompletionRequest {
234    /// Name of the model
235    model: String,
236    /// Conversation history and prompt
237    messages: Vec<Message>,
238    /// Enable streaming of responses
239    stream: bool,
240    /// Optional tools to use for generation
241    #[serde(skip_serializing_if = "Option::is_none")]
242    tools: Option<Vec<Tool>>,
243    /// Maximum number of tokens to generate
244    #[serde(skip_serializing_if = "Option::is_none")]
245    max_tokens: Option<i32>,
246}
247
248/// A tool call in the response
249#[derive(Debug, Deserialize, Clone)]
250pub struct ToolCallResponse {
251    /// Unique identifier of the tool call
252    pub id: String,
253    /// Type of tool that was called
254    #[serde(rename = "type")]
255    pub r#type: ToolType,
256    /// Function that the LLM wants to call
257    pub function: ChatCompletionMessageToolCallFunction,
258}
259
260#[derive(Debug, Deserialize, Clone)]
261pub struct ChatCompletionChoice {
262    pub finish_reason: String,
263    pub message: Message,
264    pub index: i32,
265}
266
267/// The response from generating content
268#[derive(Debug, Deserialize, Clone)]
269pub struct CreateChatCompletionResponse {
270    pub id: String,
271    pub choices: Vec<ChatCompletionChoice>,
272    pub created: i64,
273    pub model: String,
274    pub object: String,
275}
276
277/// The response from streaming content generation
278#[derive(Debug, Deserialize, Clone)]
279pub struct CreateChatCompletionStreamResponse {
280    /// A unique identifier for the chat completion. Each chunk has the same ID.
281    pub id: String,
282    /// A list of chat completion choices. Can contain more than one element if `n` is greater than 1.
283    pub choices: Vec<ChatCompletionStreamChoice>,
284    /// The Unix timestamp (in seconds) of when the chat completion was created.
285    pub created: i64,
286    /// The model used to generate the completion.
287    pub model: String,
288    /// This fingerprint represents the backend configuration that the model runs with.
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub system_fingerprint: Option<String>,
291    /// The object type, which is always "chat.completion.chunk".
292    pub object: String,
293    /// Usage statistics for the completion request.
294    #[serde(skip_serializing_if = "Option::is_none")]
295    pub usage: Option<CompletionUsage>,
296}
297
298/// Choice in a streaming completion response
299#[derive(Debug, Deserialize, Clone)]
300pub struct ChatCompletionStreamChoice {
301    /// The delta content for this streaming chunk
302    pub delta: ChatCompletionStreamDelta,
303    /// Index of the choice in the choices array
304    pub index: i32,
305    /// The reason the model stopped generating tokens
306    #[serde(skip_serializing_if = "Option::is_none")]
307    pub finish_reason: Option<String>,
308}
309
310/// Delta content for streaming responses
311#[derive(Debug, Deserialize, Clone)]
312pub struct ChatCompletionStreamDelta {
313    /// Role of the message sender
314    #[serde(skip_serializing_if = "Option::is_none")]
315    pub role: Option<MessageRole>,
316    /// Content of the message delta
317    #[serde(skip_serializing_if = "Option::is_none")]
318    pub content: Option<String>,
319    /// Tool calls for this delta
320    #[serde(skip_serializing_if = "Option::is_none")]
321    pub tool_calls: Option<Vec<ToolCallResponse>>,
322}
323
324/// Usage statistics for the completion request
325#[derive(Debug, Deserialize, Clone)]
326pub struct CompletionUsage {
327    /// Number of tokens in the generated completion
328    pub completion_tokens: i64,
329    /// Number of tokens in the prompt
330    pub prompt_tokens: i64,
331    /// Total number of tokens used in the request (prompt + completion)
332    pub total_tokens: i64,
333}
334
335/// Client for interacting with the Inference Gateway API
336pub struct InferenceGatewayClient {
337    base_url: String,
338    client: Client,
339    token: Option<String>,
340    tools: Option<Vec<Tool>>,
341    max_tokens: Option<i32>,
342}
343
344/// Implement Debug for InferenceGatewayClient
345impl std::fmt::Debug for InferenceGatewayClient {
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        f.debug_struct("InferenceGatewayClient")
348            .field("base_url", &self.base_url)
349            .field("token", &self.token.as_ref().map(|_| "*****"))
350            .finish()
351    }
352}
353
354/// Core API interface for the Inference Gateway
355pub trait InferenceGatewayAPI {
356    /// Lists available models from all providers
357    ///
358    /// # Errors
359    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
360    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
361    /// - Returns [`GatewayError::InternalError`] if the server has an error
362    /// - Returns [`GatewayError::Other`] for other errors
363    ///
364    /// # Returns
365    /// A list of models available from all providers
366    fn list_models(&self) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
367
368    /// Lists available models by a specific provider
369    ///
370    /// # Arguments
371    /// * `provider` - The LLM provider to list models for
372    ///
373    /// # Errors
374    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
375    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
376    /// - Returns [`GatewayError::InternalError`] if the server has an error
377    /// - Returns [`GatewayError::Other`] for other errors
378    ///
379    /// # Returns
380    /// A list of models available from the specified provider
381    fn list_models_by_provider(
382        &self,
383        provider: Provider,
384    ) -> impl Future<Output = Result<ListModelsResponse, GatewayError>> + Send;
385
386    /// Generates content using a specified model
387    ///
388    /// # Arguments
389    /// * `provider` - The LLM provider to use
390    /// * `model` - Name of the model
391    /// * `messages` - Conversation history and prompt
392    /// * `tools` - Optional tools to use for generation
393    ///
394    /// # Errors
395    /// - Returns [`GatewayError::Unauthorized`] if authentication fails
396    /// - Returns [`GatewayError::BadRequest`] if the request is malformed
397    /// - Returns [`GatewayError::InternalError`] if the server has an error
398    /// - Returns [`GatewayError::Other`] for other errors
399    ///
400    /// # Returns
401    /// The generated response
402    fn generate_content(
403        &self,
404        provider: Provider,
405        model: &str,
406        messages: Vec<Message>,
407    ) -> impl Future<Output = Result<CreateChatCompletionResponse, GatewayError>> + Send;
408
409    /// Stream content generation directly using the backend SSE stream.
410    ///
411    /// # Arguments
412    /// * `provider` - The LLM provider to use
413    /// * `model` - Name of the model
414    /// * `messages` - Conversation history and prompt
415    ///
416    /// # Returns
417    /// A stream of Server-Sent Events (SSE) from the Inference Gateway API
418    fn generate_content_stream(
419        &self,
420        provider: Provider,
421        model: &str,
422        messages: Vec<Message>,
423    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send;
424
425    /// Checks if the API is available
426    fn health_check(&self) -> impl Future<Output = Result<bool, GatewayError>> + Send;
427}
428
429impl InferenceGatewayClient {
430    /// Creates a new client instance
431    ///
432    /// # Arguments
433    /// * `base_url` - Base URL of the Inference Gateway API
434    pub fn new(base_url: &str) -> Self {
435        Self {
436            base_url: base_url.to_string(),
437            client: Client::new(),
438            token: None,
439            tools: None,
440            max_tokens: None,
441        }
442    }
443
444    /// Creates a new client instance with default configuration
445    /// pointing to http://localhost:8080/v1
446    pub fn new_default() -> Self {
447        let base_url = std::env::var("INFERENCE_GATEWAY_URL")
448            .unwrap_or_else(|_| "http://localhost:8080/v1".to_string());
449
450        Self {
451            base_url,
452            client: Client::new(),
453            token: None,
454            tools: None,
455            max_tokens: None,
456        }
457    }
458
459    /// Returns the base URL this client is configured to use
460    pub fn base_url(&self) -> &str {
461        &self.base_url
462    }
463
464    /// Provides tools to use for generation
465    ///
466    /// # Arguments
467    /// * `tools` - List of tools to use for generation
468    ///
469    /// # Returns
470    /// Self with the tools set
471    pub fn with_tools(mut self, tools: Option<Vec<Tool>>) -> Self {
472        self.tools = tools;
473        self
474    }
475
476    /// Sets an authentication token for the client
477    ///
478    /// # Arguments
479    /// * `token` - JWT token for authentication
480    ///
481    /// # Returns
482    /// Self with the token set
483    pub fn with_token(mut self, token: impl Into<String>) -> Self {
484        self.token = Some(token.into());
485        self
486    }
487
488    /// Sets the maximum number of tokens to generate
489    ///
490    /// # Arguments
491    /// * `max_tokens` - Maximum number of tokens to generate
492    ///
493    /// # Returns
494    /// Self with the maximum tokens set
495    pub fn with_max_tokens(mut self, max_tokens: Option<i32>) -> Self {
496        self.max_tokens = max_tokens;
497        self
498    }
499}
500
501impl InferenceGatewayAPI for InferenceGatewayClient {
502    async fn list_models(&self) -> Result<ListModelsResponse, GatewayError> {
503        let url = format!("{}/models", self.base_url);
504        let mut request = self.client.get(&url);
505        if let Some(token) = &self.token {
506            request = request.bearer_auth(token);
507        }
508
509        let response = request.send().await?;
510        match response.status() {
511            StatusCode::OK => {
512                let json_response: ListModelsResponse = response.json().await?;
513                Ok(json_response)
514            }
515            StatusCode::UNAUTHORIZED => {
516                let error: ErrorResponse = response.json().await?;
517                Err(GatewayError::Unauthorized(error.error))
518            }
519            StatusCode::BAD_REQUEST => {
520                let error: ErrorResponse = response.json().await?;
521                Err(GatewayError::BadRequest(error.error))
522            }
523            StatusCode::INTERNAL_SERVER_ERROR => {
524                let error: ErrorResponse = response.json().await?;
525                Err(GatewayError::InternalError(error.error))
526            }
527            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
528                std::io::ErrorKind::Other,
529                format!("Unexpected status code: {}", response.status()),
530            )))),
531        }
532    }
533
534    async fn list_models_by_provider(
535        &self,
536        provider: Provider,
537    ) -> Result<ListModelsResponse, GatewayError> {
538        let url = format!("{}/models?provider={}", self.base_url, provider);
539        let mut request = self.client.get(&url);
540        if let Some(token) = &self.token {
541            request = self.client.get(&url).bearer_auth(token);
542        }
543
544        let response = request.send().await?;
545        match response.status() {
546            StatusCode::OK => {
547                let json_response: ListModelsResponse = response.json().await?;
548                Ok(json_response)
549            }
550            StatusCode::UNAUTHORIZED => {
551                let error: ErrorResponse = response.json().await?;
552                Err(GatewayError::Unauthorized(error.error))
553            }
554            StatusCode::BAD_REQUEST => {
555                let error: ErrorResponse = response.json().await?;
556                Err(GatewayError::BadRequest(error.error))
557            }
558            StatusCode::INTERNAL_SERVER_ERROR => {
559                let error: ErrorResponse = response.json().await?;
560                Err(GatewayError::InternalError(error.error))
561            }
562            _ => Err(GatewayError::Other(Box::new(std::io::Error::new(
563                std::io::ErrorKind::Other,
564                format!("Unexpected status code: {}", response.status()),
565            )))),
566        }
567    }
568
569    async fn generate_content(
570        &self,
571        provider: Provider,
572        model: &str,
573        messages: Vec<Message>,
574    ) -> Result<CreateChatCompletionResponse, GatewayError> {
575        let url = format!("{}/chat/completions?provider={}", self.base_url, provider);
576        let mut request = self.client.post(&url);
577        if let Some(token) = &self.token {
578            request = request.bearer_auth(token);
579        }
580
581        let request_payload = CreateChatCompletionRequest {
582            model: model.to_string(),
583            messages,
584            stream: false,
585            tools: self.tools.clone(),
586            max_tokens: self.max_tokens,
587        };
588
589        let response = request.json(&request_payload).send().await?;
590
591        match response.status() {
592            StatusCode::OK => Ok(response.json().await?),
593            StatusCode::BAD_REQUEST => {
594                let error: ErrorResponse = response.json().await?;
595                Err(GatewayError::BadRequest(error.error))
596            }
597            StatusCode::UNAUTHORIZED => {
598                let error: ErrorResponse = response.json().await?;
599                Err(GatewayError::Unauthorized(error.error))
600            }
601            StatusCode::INTERNAL_SERVER_ERROR => {
602                let error: ErrorResponse = response.json().await?;
603                Err(GatewayError::InternalError(error.error))
604            }
605            status => Err(GatewayError::Other(Box::new(std::io::Error::new(
606                std::io::ErrorKind::Other,
607                format!("Unexpected status code: {}", status),
608            )))),
609        }
610    }
611
612    /// Stream content generation directly using the backend SSE stream.
613    fn generate_content_stream(
614        &self,
615        provider: Provider,
616        model: &str,
617        messages: Vec<Message>,
618    ) -> impl Stream<Item = Result<SSEvents, GatewayError>> + Send {
619        let client = self.client.clone();
620        let base_url = self.base_url.clone();
621        let url = format!(
622            "{}/chat/completions?provider={}",
623            base_url,
624            provider.to_string().to_lowercase()
625        );
626
627        let request = CreateChatCompletionRequest {
628            model: model.to_string(),
629            messages,
630            stream: true,
631            tools: None,
632            max_tokens: None,
633        };
634
635        async_stream::try_stream! {
636            let response = client.post(&url).json(&request).send().await?;
637            let mut stream = response.bytes_stream();
638            let mut current_event: Option<String> = None;
639            let mut current_data: Option<String> = None;
640
641            while let Some(chunk) = stream.next().await {
642                let chunk = chunk?;
643                let chunk_str = String::from_utf8_lossy(&chunk);
644
645                for line in chunk_str.lines() {
646                    if line.is_empty() && current_data.is_some() {
647                        yield SSEvents {
648                            data: current_data.take().unwrap(),
649                            event: current_event.take(),
650                            retry: None, // TODO - implement this, for now it's not implemented in the backend
651                        };
652                        continue;
653                    }
654
655                    if let Some(event) = line.strip_prefix("event:") {
656                        current_event = Some(event.trim().to_string());
657                    } else if let Some(data) = line.strip_prefix("data:") {
658                        let processed_data = data.strip_suffix('\n').unwrap_or(data);
659                        current_data = Some(processed_data.trim().to_string());
660                    }
661                }
662            }
663        }
664    }
665
666    async fn health_check(&self) -> Result<bool, GatewayError> {
667        let url = format!("{}/health", self.base_url);
668
669        let response = self.client.get(&url).send().await?;
670        match response.status() {
671            StatusCode::OK => Ok(true),
672            _ => Ok(false),
673        }
674    }
675}
676
677#[cfg(test)]
678mod tests {
679    use crate::{
680        CreateChatCompletionRequest, CreateChatCompletionResponse,
681        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
682        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
683    };
684    use futures_util::{pin_mut, StreamExt};
685    use mockito::{Matcher, Server};
686    use serde_json::json;
687
688    #[test]
689    fn test_provider_serialization() {
690        let providers = vec![
691            (Provider::Ollama, "ollama"),
692            (Provider::Groq, "groq"),
693            (Provider::OpenAI, "openai"),
694            (Provider::Cloudflare, "cloudflare"),
695            (Provider::Cohere, "cohere"),
696            (Provider::Anthropic, "anthropic"),
697        ];
698
699        for (provider, expected) in providers {
700            let json = serde_json::to_string(&provider).unwrap();
701            assert_eq!(json, format!("\"{}\"", expected));
702        }
703    }
704
705    #[test]
706    fn test_provider_deserialization() {
707        let test_cases = vec![
708            ("\"ollama\"", Provider::Ollama),
709            ("\"groq\"", Provider::Groq),
710            ("\"openai\"", Provider::OpenAI),
711            ("\"cloudflare\"", Provider::Cloudflare),
712            ("\"cohere\"", Provider::Cohere),
713            ("\"anthropic\"", Provider::Anthropic),
714        ];
715
716        for (json, expected) in test_cases {
717            let provider: Provider = serde_json::from_str(json).unwrap();
718            assert_eq!(provider, expected);
719        }
720    }
721
722    #[test]
723    fn test_message_serialization_with_tool_call_id() {
724        let message_with_tool = Message {
725            role: MessageRole::Tool,
726            content: "The weather is sunny".to_string(),
727            tool_call_id: Some("call_123".to_string()),
728            ..Default::default()
729        };
730
731        let serialized = serde_json::to_string(&message_with_tool).unwrap();
732        let expected_with_tool =
733            r#"{"role":"tool","content":"The weather is sunny","tool_call_id":"call_123"}"#;
734        assert_eq!(serialized, expected_with_tool);
735
736        let message_without_tool = Message {
737            role: MessageRole::User,
738            content: "What's the weather?".to_string(),
739            ..Default::default()
740        };
741
742        let serialized = serde_json::to_string(&message_without_tool).unwrap();
743        let expected_without_tool = r#"{"role":"user","content":"What's the weather?"}"#;
744        assert_eq!(serialized, expected_without_tool);
745
746        let deserialized: Message = serde_json::from_str(expected_with_tool).unwrap();
747        assert_eq!(deserialized.role, MessageRole::Tool);
748        assert_eq!(deserialized.content, "The weather is sunny");
749        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
750
751        let deserialized: Message = serde_json::from_str(expected_without_tool).unwrap();
752        assert_eq!(deserialized.role, MessageRole::User);
753        assert_eq!(deserialized.content, "What's the weather?");
754        assert_eq!(deserialized.tool_call_id, None);
755    }
756
757    #[test]
758    fn test_provider_display() {
759        let providers = vec![
760            (Provider::Ollama, "ollama"),
761            (Provider::Groq, "groq"),
762            (Provider::OpenAI, "openai"),
763            (Provider::Cloudflare, "cloudflare"),
764            (Provider::Cohere, "cohere"),
765            (Provider::Anthropic, "anthropic"),
766        ];
767
768        for (provider, expected) in providers {
769            assert_eq!(provider.to_string(), expected);
770        }
771    }
772
773    #[test]
774    fn test_generate_request_serialization() {
775        let request_payload = CreateChatCompletionRequest {
776            model: "llama3.2:1b".to_string(),
777            messages: vec![
778                Message {
779                    role: MessageRole::System,
780                    content: "You are a helpful assistant.".to_string(),
781                    ..Default::default()
782                },
783                Message {
784                    role: MessageRole::User,
785                    content: "What is the current weather in Toronto?".to_string(),
786                    ..Default::default()
787                },
788            ],
789            stream: false,
790            tools: Some(vec![Tool {
791                r#type: ToolType::Function,
792                function: FunctionObject {
793                    name: "get_current_weather".to_string(),
794                    description: "Get the current weather of a city".to_string(),
795                    parameters: json!({
796                        "type": "object",
797                        "properties": {
798                            "city": {
799                                "type": "string",
800                                "description": "The name of the city"
801                            }
802                        },
803                        "required": ["city"]
804                    }),
805                },
806            }]),
807            max_tokens: None,
808        };
809
810        let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
811        let expected = r#"{
812      "model": "llama3.2:1b",
813      "messages": [
814        {
815          "role": "system",
816          "content": "You are a helpful assistant."
817        },
818        {
819          "role": "user",
820          "content": "What is the current weather in Toronto?"
821        }
822      ],
823      "stream": false,
824      "tools": [
825        {
826          "type": "function",
827          "function": {
828            "name": "get_current_weather",
829            "description": "Get the current weather of a city",
830            "parameters": {
831              "type": "object",
832              "properties": {
833                "city": {
834                  "type": "string",
835                  "description": "The name of the city"
836                }
837              },
838              "required": ["city"]
839            }
840          }
841        }
842      ]
843    }"#;
844
845        assert_eq!(
846            serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
847            serde_json::from_str::<serde_json::Value>(expected).unwrap()
848        );
849    }
850
851    #[tokio::test]
852    async fn test_authentication_header() -> Result<(), GatewayError> {
853        let mut server = Server::new_async().await;
854
855        let mock_response = r#"{
856            "object": "list",
857            "data": []
858        }"#;
859
860        let mock_with_auth = server
861            .mock("GET", "/v1/models")
862            .match_header("authorization", "Bearer test-token")
863            .with_status(200)
864            .with_header("content-type", "application/json")
865            .with_body(mock_response)
866            .expect(1)
867            .create();
868
869        let base_url = format!("{}/v1", server.url());
870        let client = InferenceGatewayClient::new(&base_url).with_token("test-token");
871        client.list_models().await?;
872        mock_with_auth.assert();
873
874        let mock_without_auth = server
875            .mock("GET", "/v1/models")
876            .match_header("authorization", Matcher::Missing)
877            .with_status(200)
878            .with_header("content-type", "application/json")
879            .with_body(mock_response)
880            .expect(1)
881            .create();
882
883        let base_url = format!("{}/v1", server.url());
884        let client = InferenceGatewayClient::new(&base_url);
885        client.list_models().await?;
886        mock_without_auth.assert();
887
888        Ok(())
889    }
890
891    #[tokio::test]
892    async fn test_unauthorized_error() -> Result<(), GatewayError> {
893        let mut server = Server::new_async().await;
894
895        let raw_json_response = r#"{
896            "error": "Invalid token"
897        }"#;
898
899        let mock = server
900            .mock("GET", "/v1/models")
901            .with_status(401)
902            .with_header("content-type", "application/json")
903            .with_body(raw_json_response)
904            .create();
905
906        let base_url = format!("{}/v1", server.url());
907        let client = InferenceGatewayClient::new(&base_url);
908        let error = client.list_models().await.unwrap_err();
909
910        assert!(matches!(error, GatewayError::Unauthorized(_)));
911        if let GatewayError::Unauthorized(msg) = error {
912            assert_eq!(msg, "Invalid token");
913        }
914        mock.assert();
915
916        Ok(())
917    }
918
919    #[tokio::test]
920    async fn test_list_models() -> Result<(), GatewayError> {
921        let mut server = Server::new_async().await;
922
923        let raw_response_json = r#"{
924            "object": "list",
925            "data": [
926                {
927                    "id": "llama2",
928                    "object": "model",
929                    "created": 1630000001,
930                    "owned_by": "ollama",
931                    "served_by": "ollama"
932                }
933            ]
934        }"#;
935
936        let mock = server
937            .mock("GET", "/v1/models")
938            .with_status(200)
939            .with_header("content-type", "application/json")
940            .with_body(raw_response_json)
941            .create();
942
943        let base_url = format!("{}/v1", server.url());
944        let client = InferenceGatewayClient::new(&base_url);
945        let response = client.list_models().await?;
946
947        assert!(response.provider.is_none());
948        assert_eq!(response.object, "list");
949        assert_eq!(response.data.len(), 1);
950        assert_eq!(response.data[0].id, "llama2");
951        mock.assert();
952
953        Ok(())
954    }
955
956    #[tokio::test]
957    async fn test_list_models_by_provider() -> Result<(), GatewayError> {
958        let mut server = Server::new_async().await;
959
960        let raw_json_response = r#"{
961            "provider":"ollama",
962            "object":"list",
963            "data": [
964                {
965                    "id": "llama2",
966                    "object": "model",
967                    "created": 1630000001,
968                    "owned_by": "ollama",
969                    "served_by": "ollama"
970                }
971            ]
972        }"#;
973
974        let mock = server
975            .mock("GET", "/v1/models?provider=ollama")
976            .with_status(200)
977            .with_header("content-type", "application/json")
978            .with_body(raw_json_response)
979            .create();
980
981        let base_url = format!("{}/v1", server.url());
982        let client = InferenceGatewayClient::new(&base_url);
983        let response = client.list_models_by_provider(Provider::Ollama).await?;
984
985        assert!(response.provider.is_some());
986        assert_eq!(response.provider, Some(Provider::Ollama));
987        assert_eq!(response.data[0].id, "llama2");
988        mock.assert();
989
990        Ok(())
991    }
992
993    #[tokio::test]
994    async fn test_generate_content() -> Result<(), GatewayError> {
995        let mut server = Server::new_async().await;
996
997        let raw_json_response = r#"{
998            "id": "chatcmpl-456",
999            "object": "chat.completion",
1000            "created": 1630000001,
1001            "model": "mixtral-8x7b",
1002            "choices": [
1003                {
1004                    "index": 0,
1005                    "finish_reason": "stop",
1006                    "message": {
1007                        "role": "assistant",
1008                        "content": "Hellloooo"
1009                    }
1010                }
1011            ]
1012        }"#;
1013
1014        let mock = server
1015            .mock("POST", "/v1/chat/completions?provider=ollama")
1016            .with_status(200)
1017            .with_header("content-type", "application/json")
1018            .with_body(raw_json_response)
1019            .create();
1020
1021        let base_url = format!("{}/v1", server.url());
1022        let client = InferenceGatewayClient::new(&base_url);
1023
1024        let messages = vec![Message {
1025            role: MessageRole::User,
1026            content: "Hello".to_string(),
1027            ..Default::default()
1028        }];
1029        let response = client
1030            .generate_content(Provider::Ollama, "llama2", messages)
1031            .await?;
1032
1033        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1034        assert_eq!(response.choices[0].message.content, "Hellloooo");
1035        mock.assert();
1036
1037        Ok(())
1038    }
1039
1040    #[tokio::test]
1041    async fn test_generate_content_serialization() -> Result<(), GatewayError> {
1042        let mut server = Server::new_async().await;
1043
1044        let raw_json = r#"{
1045            "id": "chatcmpl-456",
1046            "object": "chat.completion",
1047            "created": 1630000001,
1048            "model": "mixtral-8x7b",
1049            "choices": [
1050                {
1051                    "index": 0,
1052                    "finish_reason": "stop",
1053                    "message": {
1054                        "role": "assistant",
1055                        "content": "Hello"
1056                    }
1057                }
1058            ]
1059        }"#;
1060
1061        let mock = server
1062            .mock("POST", "/v1/chat/completions?provider=groq")
1063            .with_status(200)
1064            .with_header("content-type", "application/json")
1065            .with_body(raw_json)
1066            .create();
1067
1068        let base_url = format!("{}/v1", server.url());
1069        let client = InferenceGatewayClient::new(&base_url);
1070
1071        let direct_parse: Result<CreateChatCompletionResponse, _> = serde_json::from_str(raw_json);
1072        assert!(
1073            direct_parse.is_ok(),
1074            "Direct JSON parse failed: {:?}",
1075            direct_parse.err()
1076        );
1077
1078        let messages = vec![Message {
1079            role: MessageRole::User,
1080            content: "Hello".to_string(),
1081            ..Default::default()
1082        }];
1083
1084        let response = client
1085            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1086            .await?;
1087
1088        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1089        assert_eq!(response.choices[0].message.content, "Hello");
1090
1091        mock.assert();
1092        Ok(())
1093    }
1094
1095    #[tokio::test]
1096    async fn test_generate_content_error_response() -> Result<(), GatewayError> {
1097        let mut server = Server::new_async().await;
1098
1099        let raw_json_response = r#"{
1100            "error":"Invalid request"
1101        }"#;
1102
1103        let mock = server
1104            .mock("POST", "/v1/chat/completions?provider=groq")
1105            .with_status(400)
1106            .with_header("content-type", "application/json")
1107            .with_body(raw_json_response)
1108            .create();
1109
1110        let base_url = format!("{}/v1", server.url());
1111        let client = InferenceGatewayClient::new(&base_url);
1112        let messages = vec![Message {
1113            role: MessageRole::User,
1114            content: "Hello".to_string(),
1115            ..Default::default()
1116        }];
1117        let error = client
1118            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1119            .await
1120            .unwrap_err();
1121
1122        assert!(matches!(error, GatewayError::BadRequest(_)));
1123        if let GatewayError::BadRequest(msg) = error {
1124            assert_eq!(msg, "Invalid request");
1125        }
1126        mock.assert();
1127
1128        Ok(())
1129    }
1130
1131    #[tokio::test]
1132    async fn test_gateway_errors() -> Result<(), GatewayError> {
1133        let mut server: mockito::ServerGuard = Server::new_async().await;
1134
1135        let unauthorized_mock = server
1136            .mock("GET", "/v1/models")
1137            .with_status(401)
1138            .with_header("content-type", "application/json")
1139            .with_body(r#"{"error":"Invalid token"}"#)
1140            .create();
1141
1142        let base_url = format!("{}/v1", server.url());
1143        let client = InferenceGatewayClient::new(&base_url);
1144        match client.list_models().await {
1145            Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
1146            _ => panic!("Expected Unauthorized error"),
1147        }
1148        unauthorized_mock.assert();
1149
1150        let bad_request_mock = server
1151            .mock("GET", "/v1/models")
1152            .with_status(400)
1153            .with_header("content-type", "application/json")
1154            .with_body(r#"{"error":"Invalid provider"}"#)
1155            .create();
1156
1157        match client.list_models().await {
1158            Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
1159            _ => panic!("Expected BadRequest error"),
1160        }
1161        bad_request_mock.assert();
1162
1163        let internal_error_mock = server
1164            .mock("GET", "/v1/models")
1165            .with_status(500)
1166            .with_header("content-type", "application/json")
1167            .with_body(r#"{"error":"Internal server error occurred"}"#)
1168            .create();
1169
1170        match client.list_models().await {
1171            Err(GatewayError::InternalError(msg)) => {
1172                assert_eq!(msg, "Internal server error occurred")
1173            }
1174            _ => panic!("Expected InternalError error"),
1175        }
1176        internal_error_mock.assert();
1177
1178        Ok(())
1179    }
1180
1181    #[tokio::test]
1182    async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
1183        let mut server = Server::new_async().await;
1184
1185        let raw_json = r#"{
1186            "id": "chatcmpl-456",
1187            "object": "chat.completion",
1188            "created": 1630000001,
1189            "model": "mixtral-8x7b",
1190            "choices": [
1191                {
1192                    "index": 0,
1193                    "finish_reason": "stop",
1194                    "message": {
1195                        "role": "assistant",
1196                        "content": "Hello"
1197                    }
1198                }
1199            ]
1200        }"#;
1201
1202        let mock = server
1203            .mock("POST", "/v1/chat/completions?provider=groq")
1204            .with_status(200)
1205            .with_header("content-type", "application/json")
1206            .with_body(raw_json)
1207            .create();
1208
1209        let base_url = format!("{}/v1", server.url());
1210        let client = InferenceGatewayClient::new(&base_url);
1211
1212        let messages = vec![Message {
1213            role: MessageRole::User,
1214            content: "Hello".to_string(),
1215            ..Default::default()
1216        }];
1217
1218        let response = client
1219            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1220            .await?;
1221
1222        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1223        assert_eq!(response.choices[0].message.content, "Hello");
1224        assert_eq!(response.model, "mixtral-8x7b");
1225        assert_eq!(response.object, "chat.completion");
1226        mock.assert();
1227
1228        Ok(())
1229    }
1230
1231    #[tokio::test]
1232    async fn test_generate_content_stream() -> Result<(), GatewayError> {
1233        let mut server = Server::new_async().await;
1234
1235        let mock = server
1236            .mock("POST", "/v1/chat/completions?provider=groq")
1237            .with_status(200)
1238            .with_header("content-type", "text/event-stream")
1239            .with_chunked_body(move |writer| -> std::io::Result<()> {
1240                let events = vec![
1241                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#),
1242                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#),
1243                    format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#),
1244                    format!("data: [DONE]\n\n")
1245                ];
1246                for event in events {
1247                    writer.write_all(event.as_bytes())?;
1248                }
1249                Ok(())
1250            })
1251            .create();
1252
1253        let base_url = format!("{}/v1", server.url());
1254        let client = InferenceGatewayClient::new(&base_url);
1255
1256        let messages = vec![Message {
1257            role: MessageRole::User,
1258            content: "Test message".to_string(),
1259            ..Default::default()
1260        }];
1261
1262        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
1263        pin_mut!(stream);
1264        while let Some(result) = stream.next().await {
1265            let result = result?;
1266            let generate_response: CreateChatCompletionStreamResponse =
1267                serde_json::from_str(&result.data)
1268                    .expect("Failed to parse CreateChatCompletionResponse");
1269
1270            if generate_response.choices[0].finish_reason.is_some() {
1271                assert_eq!(
1272                    generate_response.choices[0].finish_reason.as_ref().unwrap(),
1273                    "stop"
1274                );
1275                break;
1276            }
1277
1278            if let Some(content) = &generate_response.choices[0].delta.content {
1279                assert!(matches!(content.as_str(), "Hello" | " World"));
1280            }
1281            if let Some(role) = &generate_response.choices[0].delta.role {
1282                assert_eq!(role, &MessageRole::Assistant);
1283            }
1284        }
1285
1286        mock.assert();
1287        Ok(())
1288    }
1289
1290    #[tokio::test]
1291    async fn test_generate_content_stream_error() -> Result<(), GatewayError> {
1292        let mut server = Server::new_async().await;
1293
1294        let mock = server
1295            .mock("POST", "/v1/chat/completions?provider=groq")
1296            .with_status(400)
1297            .with_header("content-type", "application/json")
1298            .with_chunked_body(move |writer| -> std::io::Result<()> {
1299                let events = vec![format!(
1300                    "event: {}\ndata: {}\nretry: {}\n\n",
1301                    r#"error"#, r#"{"error":"Invalid request"}"#, r#"1000"#,
1302                )];
1303                for event in events {
1304                    writer.write_all(event.as_bytes())?;
1305                }
1306                Ok(())
1307            })
1308            .expect_at_least(1)
1309            .create();
1310
1311        let base_url = format!("{}/v1", server.url());
1312        let client = InferenceGatewayClient::new(&base_url);
1313
1314        let messages = vec![Message {
1315            role: MessageRole::User,
1316            content: "Test message".to_string(),
1317            ..Default::default()
1318        }];
1319
1320        let stream = client.generate_content_stream(Provider::Groq, "mixtral-8x7b", messages);
1321
1322        pin_mut!(stream);
1323        while let Some(result) = stream.next().await {
1324            let result = result?;
1325            assert!(result.event.is_some());
1326            assert_eq!(result.event.unwrap(), "error");
1327            assert!(result.data.contains("Invalid request"));
1328            assert!(result.retry.is_none());
1329        }
1330
1331        mock.assert();
1332        Ok(())
1333    }
1334
1335    #[tokio::test]
1336    async fn test_generate_content_with_tools() -> Result<(), GatewayError> {
1337        let mut server = Server::new_async().await;
1338
1339        let raw_json_response = r#"{
1340            "id": "chatcmpl-123",
1341            "object": "chat.completion",
1342            "created": 1630000000,
1343            "model": "deepseek-r1-distill-llama-70b",
1344            "choices": [
1345                {
1346                    "index": 0,
1347                    "finish_reason": "tool_calls",
1348                    "message": {
1349                        "role": "assistant",
1350                        "content": "Let me check the weather for you.",
1351                        "tool_calls": [
1352                            {
1353                                "id": "1234",
1354                                "type": "function",
1355                                "function": {
1356                                    "name": "get_weather",
1357                                    "arguments": "{\"location\": \"London\"}"
1358                                }
1359                            }
1360                        ]
1361                    }
1362                }
1363            ]
1364        }"#;
1365
1366        let mock = server
1367            .mock("POST", "/v1/chat/completions?provider=groq")
1368            .with_status(200)
1369            .with_header("content-type", "application/json")
1370            .with_body(raw_json_response)
1371            .create();
1372
1373        let tools = vec![Tool {
1374            r#type: ToolType::Function,
1375            function: FunctionObject {
1376                name: "get_weather".to_string(),
1377                description: "Get the weather for a location".to_string(),
1378                parameters: json!({
1379                    "type": "object",
1380                    "properties": {
1381                        "location": {
1382                            "type": "string",
1383                            "description": "The city name"
1384                        }
1385                    },
1386                    "required": ["location"]
1387                }),
1388            },
1389        }];
1390
1391        let base_url = format!("{}/v1", server.url());
1392        let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools));
1393
1394        let messages = vec![Message {
1395            role: MessageRole::User,
1396            content: "What's the weather in London?".to_string(),
1397            ..Default::default()
1398        }];
1399
1400        let response = client
1401            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
1402            .await?;
1403
1404        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1405        assert_eq!(
1406            response.choices[0].message.content,
1407            "Let me check the weather for you."
1408        );
1409
1410        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
1411        assert_eq!(tool_calls.len(), 1);
1412        assert_eq!(tool_calls[0].function.name, "get_weather");
1413
1414        let params = tool_calls[0]
1415            .function
1416            .parse_arguments()
1417            .expect("Failed to parse function arguments");
1418        assert_eq!(params["location"].as_str().unwrap(), "London");
1419
1420        mock.assert();
1421        Ok(())
1422    }
1423
1424    #[tokio::test]
1425    async fn test_generate_content_without_tools() -> Result<(), GatewayError> {
1426        let mut server = Server::new_async().await;
1427
1428        let raw_json_response = r#"{
1429            "id": "chatcmpl-123",
1430            "object": "chat.completion",
1431            "created": 1630000000,
1432            "model": "gpt-4",
1433            "choices": [
1434                {
1435                    "index": 0,
1436                    "finish_reason": "stop",
1437                    "message": {
1438                        "role": "assistant",
1439                        "content": "Hello!"
1440                    }
1441                }
1442            ]
1443        }"#;
1444
1445        let mock = server
1446            .mock("POST", "/v1/chat/completions?provider=openai")
1447            .with_status(200)
1448            .with_header("content-type", "application/json")
1449            .with_body(raw_json_response)
1450            .create();
1451
1452        let base_url = format!("{}/v1", server.url());
1453        let client = InferenceGatewayClient::new(&base_url);
1454
1455        let messages = vec![Message {
1456            role: MessageRole::User,
1457            content: "Hi".to_string(),
1458            ..Default::default()
1459        }];
1460
1461        let response = client
1462            .generate_content(Provider::OpenAI, "gpt-4", messages)
1463            .await?;
1464
1465        assert_eq!(response.model, "gpt-4");
1466        assert_eq!(response.choices[0].message.content, "Hello!");
1467        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1468        assert!(response.choices[0].message.tool_calls.is_none());
1469
1470        mock.assert();
1471        Ok(())
1472    }
1473
1474    #[tokio::test]
1475    async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> {
1476        let mut server = Server::new_async().await;
1477
1478        let raw_request_body = r#"{
1479            "model": "deepseek-r1-distill-llama-70b",
1480            "messages": [
1481                {
1482                    "role": "system",
1483                    "content": "You are a helpful assistant."
1484                },
1485                {
1486                    "role": "user",
1487                    "content": "What is the current weather in Toronto?"
1488                }
1489            ],
1490            "stream": false,
1491            "tools": [
1492                {
1493                    "type": "function",
1494                    "function": {
1495                        "name": "get_current_weather",
1496                        "description": "Get the current weather of a city",
1497                        "parameters": {
1498                            "type": "object",
1499                            "properties": {
1500                                "city": {
1501                                    "type": "string",
1502                                    "description": "The name of the city"
1503                                }
1504                            },
1505                            "required": ["city"]
1506                        }
1507                    }
1508                }
1509            ]
1510        }"#;
1511
1512        let raw_json_response = r#"{
1513            "id": "1234",
1514            "object": "chat.completion",
1515            "created": 1630000000,
1516            "model": "deepseek-r1-distill-llama-70b",
1517            "choices": [
1518                {
1519                    "index": 0,
1520                    "finish_reason": "stop",
1521                    "message": {
1522                        "role": "assistant",
1523                        "content": "Let me check the weather for you",
1524                        "tool_calls": [
1525                            {
1526                                "id": "1234",
1527                                "type": "function",
1528                                "function": {
1529                                    "name": "get_current_weather",
1530                                    "arguments": "{\"city\": \"Toronto\"}"
1531                                }
1532                            }
1533                        ]
1534                    }
1535                }
1536            ]
1537        }"#;
1538
1539        let mock = server
1540            .mock("POST", "/v1/chat/completions?provider=groq")
1541            .with_status(200)
1542            .with_header("content-type", "application/json")
1543            .match_body(mockito::Matcher::JsonString(raw_request_body.to_string()))
1544            .with_body(raw_json_response)
1545            .create();
1546
1547        let tools = vec![Tool {
1548            r#type: ToolType::Function,
1549            function: FunctionObject {
1550                name: "get_current_weather".to_string(),
1551                description: "Get the current weather of a city".to_string(),
1552                parameters: json!({
1553                    "type": "object",
1554                    "properties": {
1555                        "city": {
1556                            "type": "string",
1557                            "description": "The name of the city"
1558                        }
1559                    },
1560                    "required": ["city"]
1561                }),
1562            },
1563        }];
1564
1565        let base_url = format!("{}/v1", server.url());
1566        let client = InferenceGatewayClient::new(&base_url);
1567
1568        let messages = vec![
1569            Message {
1570                role: MessageRole::System,
1571                content: "You are a helpful assistant.".to_string(),
1572                ..Default::default()
1573            },
1574            Message {
1575                role: MessageRole::User,
1576                content: "What is the current weather in Toronto?".to_string(),
1577                ..Default::default()
1578            },
1579        ];
1580
1581        let response = client
1582            .with_tools(Some(tools))
1583            .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
1584            .await?;
1585
1586        assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
1587        assert_eq!(
1588            response.choices[0].message.content,
1589            "Let me check the weather for you"
1590        );
1591        assert_eq!(
1592            response.choices[0]
1593                .message
1594                .tool_calls
1595                .as_ref()
1596                .unwrap()
1597                .len(),
1598            1
1599        );
1600
1601        mock.assert();
1602        Ok(())
1603    }
1604
1605    #[tokio::test]
1606    async fn test_generate_content_with_max_tokens() -> Result<(), GatewayError> {
1607        let mut server = Server::new_async().await;
1608
1609        let raw_json_response = r#"{
1610            "id": "chatcmpl-123",
1611            "object": "chat.completion",
1612            "created": 1630000000,
1613            "model": "mixtral-8x7b",
1614            "choices": [
1615                {
1616                    "index": 0,
1617                    "finish_reason": "stop",
1618                    "message": {
1619                        "role": "assistant",
1620                        "content": "Here's a poem with 100 tokens..."
1621                    }
1622                }
1623            ]
1624        }"#;
1625
1626        let mock = server
1627            .mock("POST", "/v1/chat/completions?provider=groq")
1628            .with_status(200)
1629            .with_header("content-type", "application/json")
1630            .match_body(mockito::Matcher::JsonString(
1631                r#"{
1632                "model": "mixtral-8x7b",
1633                "messages": [{"role":"user","content":"Write a poem"}],
1634                "stream": false,
1635                "max_tokens": 100
1636            }"#
1637                .to_string(),
1638            ))
1639            .with_body(raw_json_response)
1640            .create();
1641
1642        let base_url = format!("{}/v1", server.url());
1643        let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100));
1644
1645        let messages = vec![Message {
1646            role: MessageRole::User,
1647            content: "Write a poem".to_string(),
1648            ..Default::default()
1649        }];
1650
1651        let response = client
1652            .generate_content(Provider::Groq, "mixtral-8x7b", messages)
1653            .await?;
1654
1655        assert_eq!(
1656            response.choices[0].message.content,
1657            "Here's a poem with 100 tokens..."
1658        );
1659        assert_eq!(response.model, "mixtral-8x7b");
1660        assert_eq!(response.created, 1630000000);
1661        assert_eq!(response.object, "chat.completion");
1662
1663        mock.assert();
1664        Ok(())
1665    }
1666
1667    #[tokio::test]
1668    async fn test_health_check() -> Result<(), GatewayError> {
1669        let mut server = Server::new_async().await;
1670        let mock = server.mock("GET", "/health").with_status(200).create();
1671
1672        let client = InferenceGatewayClient::new(&server.url());
1673        let is_healthy = client.health_check().await?;
1674
1675        assert!(is_healthy);
1676        mock.assert();
1677
1678        Ok(())
1679    }
1680
1681    #[tokio::test]
1682    async fn test_client_base_url_configuration() -> Result<(), GatewayError> {
1683        let mut custom_url_server = Server::new_async().await;
1684
1685        let custom_url_mock = custom_url_server
1686            .mock("GET", "/health")
1687            .with_status(200)
1688            .create();
1689
1690        let custom_client = InferenceGatewayClient::new(&custom_url_server.url());
1691        let is_healthy = custom_client.health_check().await?;
1692        assert!(is_healthy);
1693        custom_url_mock.assert();
1694
1695        let default_client = InferenceGatewayClient::new_default();
1696
1697        let default_url = "http://localhost:8080/v1";
1698        assert_eq!(default_client.base_url(), default_url);
1699
1700        Ok(())
1701    }
1702}