use reqwest::{blocking::Client, StatusCode};
use serde::{Deserialize, Serialize};
use std::{error::Error, fmt};
/// Errors produced by [`InferenceGatewayClient`] operations.
#[derive(Debug)]
pub enum GatewayError {
    /// `401 Unauthorized`, carrying the `error` field of the gateway's JSON body.
    Unauthorized(String),
    /// `400 Bad Request`, carrying the `error` field of the gateway's JSON body.
    BadRequest(String),
    /// `500 Internal Server Error`, carrying the `error` field of the gateway's JSON body.
    InternalError(String),
    /// Transport failure or a response body that could not be decoded.
    RequestError(reqwest::Error),
    /// Any other failure, such as an unexpected HTTP status code.
    Other(Box<dyn Error + Send + Sync + 'static>),
}
impl fmt::Display for GatewayError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, detail) = match self {
            Self::Unauthorized(msg) => ("Unauthorized", msg as &dyn fmt::Display),
            Self::BadRequest(msg) => ("Bad request", msg as &dyn fmt::Display),
            Self::InternalError(msg) => ("Internal server error", msg as &dyn fmt::Display),
            Self::RequestError(err) => ("Request error", err as &dyn fmt::Display),
            Self::Other(err) => ("Other error", err as &dyn fmt::Display),
        };
        write!(f, "{}: {}", label, detail)
    }
}
impl Error for GatewayError {
    /// Exposes the wrapped error for the variants that carry one.
    ///
    /// The message-only variants have no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Unauthorized(_) | Self::BadRequest(_) | Self::InternalError(_) => None,
            Self::RequestError(inner) => Some(inner),
            Self::Other(inner) => Some(&**inner),
        }
    }
}
impl From<reqwest::Error> for GatewayError {
fn from(err: reqwest::Error) -> Self {
Self::RequestError(err)
}
}
/// JSON body the gateway sends alongside 401, 400 and 500 responses.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// Human-readable description of what went wrong.
    error: String,
}
/// A single model entry as reported by the gateway's model-listing endpoints.
///
/// All fields are plain owned data, so the type is cheaply `Clone` and comparable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Model {
    /// Model identifier, e.g. `"llama2"`.
    pub id: String,
    /// Object type tag reported by the gateway (presumably always `"model"` — unverified).
    pub object: String,
    /// Owner/vendor of the model as reported by the gateway.
    pub owned_by: String,
    /// Creation time as reported by the gateway; looks like Unix seconds — TODO confirm.
    pub created: i64,
}
/// The models offered by one provider.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderModels {
    /// Provider these models belong to.
    pub provider: Provider,
    /// Models available from `provider`.
    pub models: Vec<Model>,
}
/// LLM providers reachable through the gateway.
///
/// Serialized as the lowercase variant name (e.g. `OpenAI` ↔ `"openai"`). The type is
/// `Copy` because the API methods take it by value; without it, a caller could not
/// reuse the same provider value across several calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    Ollama,
    Groq,
    OpenAI,
    Google,
    Cloudflare,
    Cohere,
    Anthropic,
}
impl fmt::Display for Provider {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Provider::Ollama => write!(f, "ollama"),
Provider::Groq => write!(f, "groq"),
Provider::OpenAI => write!(f, "openai"),
Provider::Google => write!(f, "google"),
Provider::Cloudflare => write!(f, "cloudflare"),
Provider::Cohere => write!(f, "cohere"),
Provider::Anthropic => write!(f, "anthropic"),
}
}
}
/// Author of a chat message; serialized as the lowercase variant name.
///
/// `Eq` and `Hash` are derived so roles can be used as map/set keys.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    User,
    Assistant,
}
impl fmt::Display for MessageRole {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MessageRole::System => write!(f, "system"),
MessageRole::User => write!(f, "user"),
MessageRole::Assistant => write!(f, "assistant"),
}
}
}
/// One turn of a conversation sent to the gateway.
///
/// `Clone` is derived because `generate_content` consumes its `Vec<Message>`. Callers
/// that want to keep the conversation history must be able to copy it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
    /// Who authored this message.
    pub role: MessageRole,
    /// Message text.
    pub content: String,
}
/// JSON payload POSTed to `/llms/{provider}/generate`.
#[derive(Debug, Serialize)]
struct GenerateRequest {
    /// Name of the model that should produce the reply.
    model: String,
    /// Conversation so far.
    messages: Vec<Message>,
}
/// Successful reply from the generate endpoint.
#[derive(Debug, Deserialize)]
pub struct GenerateResponse {
    /// Provider that served the request.
    pub provider: Provider,
    /// The generated message.
    pub response: ResponseContent,
}
/// The message produced by a generate call.
///
/// `Clone` and `PartialEq` are derived for consistency with [`Message`], so replies can
/// be stored or compared without re-requesting them.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ResponseContent {
    /// Role of the author of the reply.
    pub role: MessageRole,
    /// Model that generated the reply.
    pub model: String,
    /// Generated text.
    pub content: String,
}
/// Blocking HTTP client for an inference gateway.
///
/// Build one with [`InferenceGatewayClient::new`] and optionally attach credentials with
/// [`InferenceGatewayClient::with_token`].
pub struct InferenceGatewayClient {
    /// Root URL to which endpoint paths such as `/llms` are appended.
    base_url: String,
    /// Underlying blocking `reqwest` client.
    client: Client,
    /// Optional bearer token used to authenticate model and generation requests.
    token: Option<String>,
}
/// Operations offered by an inference gateway.
pub trait InferenceGatewayAPI {
    /// Lists the models of every provider known to the gateway.
    fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError>;

    /// Lists the models offered by a single `provider`.
    fn list_models_by_provider(&self, provider: Provider) -> Result<ProviderModels, GatewayError>;

    /// Asks `provider` to have `model` reply to the conversation in `messages`.
    fn generate_content(
        &self,
        provider: Provider,
        model: &str,
        messages: Vec<Message>,
    ) -> Result<GenerateResponse, GatewayError>;

    /// Reports whether the gateway is healthy.
    fn health_check(&self) -> Result<bool, Box<dyn Error>>;
}
impl InferenceGatewayClient {
    /// Creates a client for the gateway rooted at `base_url`, with no token attached.
    ///
    /// Trailing `/` characters are stripped. Endpoint paths are appended with a
    /// leading `/`, so `"http://host:8080/"` would otherwise yield `http://host:8080//llms`.
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.trim_end_matches('/').to_string(),
            client: Client::new(),
            token: None,
        }
    }

    /// Returns this client configured to send `token` as an `Authorization: Bearer` header.
    #[must_use]
    pub fn with_token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }
}
impl InferenceGatewayAPI for InferenceGatewayClient {
fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError> {
let url = format!("{}/llms", self.base_url);
let mut request = self.client.get(&url);
if let Some(token) = &self.token {
request = request.bearer_auth(token);
}
let response = request.send()?;
match response.status() {
StatusCode::OK => Ok(response.json()?),
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::BAD_REQUEST => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::BadRequest(error.error))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::InternalError(error.error))
}
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Unexpected status code: {}", response.status()),
)))),
}
}
fn list_models_by_provider(&self, provider: Provider) -> Result<ProviderModels, GatewayError> {
let url = format!("{}/llms/{}", self.base_url, provider);
let mut request = self.client.get(&url);
if let Some(token) = &self.token {
request = self.client.get(&url).bearer_auth(token);
}
let response = request.send()?;
match response.status() {
StatusCode::OK => Ok(response.json()?),
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::BAD_REQUEST => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::BadRequest(error.error))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::InternalError(error.error))
}
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Unexpected status code: {}", response.status()),
)))),
}
}
fn generate_content(
&self,
provider: Provider,
model: &str,
messages: Vec<Message>,
) -> Result<GenerateResponse, GatewayError> {
let url = format!("{}/llms/{}/generate", self.base_url, provider);
let mut request = self.client.post(&url);
if let Some(token) = &self.token {
request = request.bearer_auth(token);
}
let request_payload = GenerateRequest {
model: model.to_string(),
messages,
};
let response = request.json(&request_payload).send()?;
match response.status() {
StatusCode::OK => Ok(response.json()?),
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::BAD_REQUEST => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::BadRequest(error.error))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error: ErrorResponse = response.json()?;
Err(GatewayError::InternalError(error.error))
}
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Unexpected status code: {}", response.status()),
)))),
}
}
fn health_check(&self) -> Result<bool, Box<dyn Error>> {
let url = format!("{}/health", self.base_url);
let response = self.client.get(&url).send()?;
Ok(response.status().is_success())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{Matcher, Server};

    #[test]
    fn test_authentication_header() {
        let mut server = Server::new();
        let mock_with_auth = server
            .mock("GET", "/llms")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("[]")
            .expect(1)
            .create();
        let client = InferenceGatewayClient::new(&server.url()).with_token("test-token");
        client.list_models().unwrap();
        mock_with_auth.assert();

        let mock_without_auth = server
            .mock("GET", "/llms")
            .match_header("authorization", Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("[]")
            .expect(1)
            .create();
        let client = InferenceGatewayClient::new(&server.url());
        client.list_models().unwrap();
        mock_without_auth.assert();
    }

    /// The per-provider endpoint must forward the token just like `/llms`.
    #[test]
    fn test_authentication_header_for_provider_models() {
        let mut server = Server::new();
        let mock = server
            .mock("GET", "/llms/ollama")
            .match_header("authorization", "Bearer test-token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"provider":"ollama","models":[]}"#)
            .expect(1)
            .create();
        let client = InferenceGatewayClient::new(&server.url()).with_token("test-token");
        client.list_models_by_provider(Provider::Ollama).unwrap();
        mock.assert();
    }

    #[test]
    fn test_unauthorized_error() {
        let mut server = Server::new();
        let mock = server
            .mock("GET", "/llms")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid token"}"#)
            .create();
        let client = InferenceGatewayClient::new(&server.url());
        match client.list_models().unwrap_err() {
            GatewayError::Unauthorized(msg) => assert_eq!(msg, "Invalid token"),
            other => panic!("expected Unauthorized, got {:?}", other),
        }
        mock.assert();
    }

    #[test]
    fn test_bad_request_error() {
        let mut server = Server::new();
        let mock = server
            .mock("POST", "/llms/openai/generate")
            .with_status(400)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Invalid model"}"#)
            .create();
        let client = InferenceGatewayClient::new(&server.url());
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
        }];
        match client
            .generate_content(Provider::OpenAI, "no-such-model", messages)
            .unwrap_err()
        {
            GatewayError::BadRequest(msg) => assert_eq!(msg, "Invalid model"),
            other => panic!("expected BadRequest, got {:?}", other),
        }
        mock.assert();
    }

    #[test]
    fn test_internal_server_error() {
        let mut server = Server::new();
        let mock = server
            .mock("GET", "/llms/groq")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":"Upstream failure"}"#)
            .create();
        let client = InferenceGatewayClient::new(&server.url());
        match client.list_models_by_provider(Provider::Groq).unwrap_err() {
            GatewayError::InternalError(msg) => assert_eq!(msg, "Upstream failure"),
            other => panic!("expected InternalError, got {:?}", other),
        }
        mock.assert();
    }

    #[test]
    fn test_unexpected_status() {
        let mut server = Server::new();
        let mock = server.mock("GET", "/llms").with_status(404).create();
        let client = InferenceGatewayClient::new(&server.url());
        match client.list_models().unwrap_err() {
            GatewayError::Other(err) => assert!(err.to_string().contains("404")),
            other => panic!("expected Other, got {:?}", other),
        }
        mock.assert();
    }

    #[test]
    fn test_list_models() {
        let mut server = Server::new();
        let mock = server
            .mock("GET", "/llms")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"[{"provider":"ollama","models":[{"id":"llama2","object":"model","owned_by":"meta","created":1600000000}]}]"#)
            .create();
        let client = InferenceGatewayClient::new(&server.url());
        let models = client.list_models().unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].models[0].id, "llama2");
        mock.assert();
    }

    #[test]
    fn test_get_provider_models() {
        let mut server = Server::new();
        let mock = server
            .mock("GET", "/llms/ollama")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"provider":"ollama","models":[{"id":"llama2","object":"model","owned_by":"meta","created":1600000000}]}"#)
            .create();
        let client = InferenceGatewayClient::new(&server.url());
        let models = client.list_models_by_provider(Provider::Ollama).unwrap();
        assert_eq!(models.provider, Provider::Ollama);
        assert_eq!(models.models[0].id, "llama2");
        mock.assert();
    }

    #[test]
    fn test_generate_content() {
        let mut server = Server::new();
        let mock = server
            .mock("POST", "/llms/ollama/generate")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"provider":"ollama","response":{"role":"assistant","model":"llama2","content":"Hellloooo"}}"#)
            .create();
        let client = InferenceGatewayClient::new(&server.url());
        let messages = vec![Message {
            role: MessageRole::User,
            content: "Hello".to_string(),
        }];
        let response = client
            .generate_content(Provider::Ollama, "llama2", messages)
            .unwrap();
        assert_eq!(response.provider, Provider::Ollama);
        assert_eq!(response.response.role, MessageRole::Assistant);
        assert_eq!(response.response.model, "llama2");
        assert_eq!(response.response.content, "Hellloooo");
        mock.assert();
    }

    #[test]
    fn test_health_check() {
        let mut server = Server::new();
        let mock = server.mock("GET", "/health").with_status(200).create();
        let client = InferenceGatewayClient::new(&server.url());
        let is_healthy = client.health_check().unwrap();
        assert!(is_healthy);
        mock.assert();
    }
}