use crate::client::ModelClient;
use crate::error::{OllamaError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single chat message: a `role` string paired with its `content`.
///
/// Used as an element of [`ChatCompletionsRequest::messages`] and as
/// [`Choice::message`] in responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Role of the message author, kept as a free-form string (not validated here).
pub role: String,
/// Message payload as arbitrary JSON, so it is not restricted to a plain string.
pub content: serde_json::Value,
}
/// JSON request body that [`ModelClient::chat_completions`] posts to
/// `v1/chat/completions`.
///
/// `model` and `messages` are always serialized. Every `Option` field is
/// omitted from the JSON when it is `None`, so unset options never reach the
/// server. `Default::default()` gives an empty model name, no messages and all
/// options unset.
///
/// NOTE(review): the option names presumably mirror the OpenAI-style chat
/// completions parameters; their server-side semantics are not visible in this
/// file — confirm against the server's API reference.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatCompletionsRequest {
/// Model name; `chat_completions` also hands it to `handle_response`.
pub model: String,
/// Messages to send, serialized in the order given.
pub messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Free-form JSON, passed through without validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
/// NOTE(review): `chat_completions` returns one `ChatCompletionsResponse`;
/// whether `Some(true)` works depends on `handle_response`, which is not
/// shown here — confirm before enabling streaming through this method.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Tool definitions as free-form JSON values, passed through without validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
/// Free-form JSON, passed through without validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<serde_json::Value>,
/// Free-form JSON, passed through without validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<serde_json::Value>,
/// Map from string keys to bias values; serialized as a JSON object.
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
}
/// Nested options object carried by [`ChatCompletionsRequest::stream_options`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
/// Omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub include_usage: Option<bool>,
}
/// One entry of [`ChatCompletionsResponse::choices`].
///
/// All three fields are required: deserialization fails if any key is missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
/// Index reported by the server for this choice.
pub index: u32,
/// The message carried by this choice.
pub message: ChatMessage,
/// Finish reason as a free-form string; a JSON `null` here would not deserialize.
pub finish_reason: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
/// Response body returned by [`ModelClient::chat_completions`].
///
/// Every field is required when deserializing; keys not listed here are ignored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
/// Identifier reported by the server for this completion.
pub id: String,
/// The returned choices, in the order the server sent them.
pub choices: Vec<Choice>,
/// Creation timestamp as an unsigned integer (presumably Unix seconds — confirm).
pub created: u64,
/// Model name reported by the server.
pub model: String,
/// Token accounting for this completion.
pub usage: Usage,
}
/// Input payload for [`OpenAIEmbeddingsRequest::input`].
///
/// The enum is `untagged`: it serializes as the bare inner value (a string or
/// an array) with no variant name. When deserializing, variants are tried in
/// declaration order and the first one that fits wins, so an empty JSON array
/// becomes `Multiple(vec![])`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum OpenAIEmbeddingsInput {
/// A single string.
Single(String),
/// A list of strings.
Multiple(Vec<String>),
/// A single list of unsigned integers (presumably token ids — confirm).
Tokens(Vec<u32>),
/// A list of integer lists (presumably one token-id list per input — confirm).
TokenArrays(Vec<Vec<u32>>),
}
impl Default for OpenAIEmbeddingsInput {
fn default() -> Self {
Self::Single(String::new())
}
}
/// JSON request body that [`ModelClient::openai_embeddings`] posts to
/// `v1/embeddings`.
///
/// `model` and `input` are always serialized; the `Option` fields are omitted
/// from the JSON when `None`. `Default::default()` gives an empty model name
/// and a single empty input string.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OpenAIEmbeddingsRequest {
/// Model name; `openai_embeddings` also hands it to `handle_response`.
pub model: String,
/// What to embed; see [`OpenAIEmbeddingsInput`] for the accepted shapes.
pub input: OpenAIEmbeddingsInput,
/// NOTE(review): [`OpenAIEmbedding::embedding`] is a `Vec<f32>`, so a format
/// that makes the server return a non-array value would presumably fail to
/// deserialize — confirm which values are safe to send.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
/// One entry of [`OpenAIEmbeddingsResponse::data`].
///
/// All three fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIEmbedding {
/// The embedding vector; must arrive as a JSON array of numbers.
pub embedding: Vec<f32>,
/// Index reported by the server for this entry.
pub index: u32,
/// Object type string reported by the server.
pub object: String,
}
/// Response body returned by [`ModelClient::openai_embeddings`].
///
/// Every field is required when deserializing; keys not listed here are ignored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIEmbeddingsResponse {
/// The returned embeddings, in the order the server sent them.
pub data: Vec<OpenAIEmbedding>,
/// Model name reported by the server.
pub model: String,
/// Token accounting reported by the server.
pub usage: Usage,
}
/// JSON request body that [`ModelClient::responses`] posts to `v1/responses`.
///
/// Only `model` is always serialized; every `Option` field is omitted from the
/// JSON when it is `None`. `Default::default()` gives an empty model name with
/// all options unset.
///
/// NOTE(review): the option names presumably mirror an OpenAI-style Responses
/// request; their server-side semantics are not visible in this file —
/// confirm against the server's API reference.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ResponsesRequest {
/// Model name; `responses` also hands it to `handle_response`.
pub model: String,
/// Input text; only a plain string can be expressed with this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub input: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Tool definitions as free-form JSON values, passed through without validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<serde_json::Value>>,
/// NOTE(review): `responses` returns one `ResponsesResponse`; whether
/// `Some(true)` works depends on `handle_response`, which is not shown here.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
/// A list of free-form JSON values, passed through without validation.
#[serde(skip_serializing_if = "Option::is_none")]
pub conversation: Option<Vec<serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncation: Option<String>,
}
/// Response body returned by [`ModelClient::responses`].
///
/// Every field is required: none is an `Option` or has a serde default, so a
/// payload missing any of these keys fails to deserialize.
///
/// NOTE(review): this field set (`done`, `done_reason`, `*_duration`,
/// `eval_count`, ...) does not look like an OpenAI-style Responses object —
/// confirm that the server really returns this shape from `v1/responses`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesResponse {
/// Output as a single string.
pub output: String,
pub done: bool,
/// Model name reported by the server.
pub model: String,
pub done_reason: String,
/// Tool calls as free-form JSON values.
pub tool_calls: Vec<serde_json::Value>,
pub prompt_evals: u32,
pub eval_count: u32,
// The duration fields are plain unsigned integers; the unit is not stated
// in this file (presumably nanoseconds — TODO confirm).
pub total_duration: u64,
pub load_duration: u64,
pub prompt_eval_duration: u64,
pub eval_duration: u64,
pub output_eval_count: u32,
pub output_eval_duration: u64,
}
impl ModelClient {
    /// Posts `request` as JSON to the `v1/chat/completions` endpoint.
    ///
    /// The endpoint is resolved relative to `self.base_url`; the HTTP reply is
    /// passed to `handle_response` together with the requested model name.
    ///
    /// # Errors
    ///
    /// * `OllamaError::UrlError` if the endpoint cannot be joined onto the base URL.
    /// * `OllamaError::RequestError` if sending the request fails.
    /// * Whatever `handle_response` reports for the received reply.
    pub async fn chat_completions(
        &self,
        request: ChatCompletionsRequest,
    ) -> Result<ChatCompletionsResponse> {
        let endpoint = self.base_url.join("v1/chat/completions");
        let endpoint = endpoint.map_err(OllamaError::UrlError)?;

        let builder = self.client.post(endpoint).json(&request);
        let reply = builder.send().await.map_err(OllamaError::RequestError)?;

        self.handle_response(reply, Some(&request.model)).await
    }

    /// Posts `request` as JSON to the `v1/embeddings` endpoint.
    ///
    /// The endpoint is resolved relative to `self.base_url`; the HTTP reply is
    /// passed to `handle_response` together with the requested model name.
    ///
    /// # Errors
    ///
    /// * `OllamaError::UrlError` if the endpoint cannot be joined onto the base URL.
    /// * `OllamaError::RequestError` if sending the request fails.
    /// * Whatever `handle_response` reports for the received reply.
    pub async fn openai_embeddings(
        &self,
        request: OpenAIEmbeddingsRequest,
    ) -> Result<OpenAIEmbeddingsResponse> {
        let endpoint = self.base_url.join("v1/embeddings");
        let endpoint = endpoint.map_err(OllamaError::UrlError)?;

        let builder = self.client.post(endpoint).json(&request);
        let reply = builder.send().await.map_err(OllamaError::RequestError)?;

        self.handle_response(reply, Some(&request.model)).await
    }

    /// Posts `request` as JSON to the `v1/responses` endpoint.
    ///
    /// The endpoint is resolved relative to `self.base_url`; the HTTP reply is
    /// passed to `handle_response` together with the requested model name.
    ///
    /// # Errors
    ///
    /// * `OllamaError::UrlError` if the endpoint cannot be joined onto the base URL.
    /// * `OllamaError::RequestError` if sending the request fails.
    /// * Whatever `handle_response` reports for the received reply.
    pub async fn responses(&self, request: ResponsesRequest) -> Result<ResponsesResponse> {
        let endpoint = self.base_url.join("v1/responses");
        let endpoint = endpoint.map_err(OllamaError::UrlError)?;

        let builder = self.client.post(endpoint).json(&request);
        let reply = builder.send().await.map_err(OllamaError::RequestError)?;

        self.handle_response(reply, Some(&request.model)).await
    }
}