use crate::client::Otari;
use crate::config::Config;
use crate::error::Result;
use crate::types::{
ChatCompletion, CompletionParams, CompletionStream, Message, ReasoningEffort, RerankParams,
RerankResponse, StopSequence, Tool, ToolChoice,
};
use serde_json::Value;
/// Optional settings for [`completion`] and [`completion_stream`].
///
/// `api_key` and `api_base` are used only to build the client's [`Config`]
/// (see the `From<CompletionOptions> for Config` impl). Every other field is
/// copied unchanged into the same-named [`CompletionParams`] field, so a
/// `None` here is forwarded as `None`. `stream` is intentionally absent: it is
/// decided by which of the two functions is called.
///
/// Fields may be set directly (all are public, and `Default` yields all
/// `None`) or through the builder methods on this type.
#[derive(Debug, Clone, Default)]
pub struct CompletionOptions {
    /// API key placed into `Config::api_key` when the client is built.
    pub api_key: Option<String>,
    /// API base value placed into `Config::api_base` when the client is built.
    pub api_base: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub max_tokens: Option<u32>,
    pub n: Option<u32>,
    pub stop: Option<StopSequence>,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub seed: Option<i64>,
    pub user: Option<String>,
    pub parallel_tool_calls: Option<bool>,
    pub logprobs: Option<bool>,
    pub top_logprobs: Option<u32>,
    // Keyed by String; the key format (e.g. token ids as strings) is not
    // established here — NOTE(review): confirm against the provider contract.
    pub logit_bias: Option<std::collections::HashMap<String, f32>>,
    // Arbitrary JSON forwarded as-is; no schema is enforced in this module.
    pub response_format: Option<Value>,
    pub reasoning_effort: Option<ReasoningEffort>,
}
impl CompletionOptions {
pub fn with_api_key(api_key: impl Into<String>) -> Self {
Self {
api_key: Some(api_key.into()),
..Default::default()
}
}
pub fn api_base(mut self, api_base: impl Into<String>) -> Self {
self.api_base = Some(api_base.into());
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools);
self
}
pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
self.tool_choice = Some(tool_choice);
self
}
pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
self.reasoning_effort = Some(effort);
self
}
}
impl From<CompletionOptions> for Config {
fn from(options: CompletionOptions) -> Self {
Config {
api_key: options.api_key,
api_base: options.api_base,
extra: Default::default(),
}
}
}
pub async fn completion(
model: &str,
messages: Vec<Message>,
options: CompletionOptions,
) -> Result<ChatCompletion> {
let model_id = model.to_string();
let client = Otari::from_config(options.clone().into())?;
let params = CompletionParams {
model_id,
messages,
tools: options.tools,
tool_choice: options.tool_choice,
temperature: options.temperature,
top_p: options.top_p,
max_tokens: options.max_tokens,
stream: Some(false),
n: options.n,
stop: options.stop,
presence_penalty: options.presence_penalty,
frequency_penalty: options.frequency_penalty,
seed: options.seed,
user: options.user,
parallel_tool_calls: options.parallel_tool_calls,
logprobs: options.logprobs,
top_logprobs: options.top_logprobs,
logit_bias: options.logit_bias,
response_format: options.response_format,
reasoning_effort: options.reasoning_effort,
};
client.completion(params).await
}
pub async fn completion_stream(
model: &str,
messages: Vec<Message>,
options: CompletionOptions,
) -> Result<CompletionStream> {
let model_id = model.to_string();
let client = Otari::from_config(options.clone().into())?;
let params = CompletionParams {
model_id,
messages,
tools: options.tools,
tool_choice: options.tool_choice,
temperature: options.temperature,
top_p: options.top_p,
max_tokens: options.max_tokens,
stream: Some(true),
n: options.n,
stop: options.stop,
presence_penalty: options.presence_penalty,
frequency_penalty: options.frequency_penalty,
seed: options.seed,
user: options.user,
parallel_tool_calls: options.parallel_tool_calls,
logprobs: options.logprobs,
top_logprobs: options.top_logprobs,
logit_bias: options.logit_bias,
response_format: options.response_format,
reasoning_effort: options.reasoning_effort,
};
client.completion_stream(params).await
}
/// Optional settings for [`rerank`].
///
/// `api_key` and `api_base` are used only to build the client's [`Config`]
/// (see the `From<RerankOptions> for Config` impl). `top_n`,
/// `max_tokens_per_doc` and `user` are copied unchanged into the same-named
/// [`RerankParams`] fields, so a `None` here is forwarded as `None`.
#[derive(Debug, Clone, Default)]
pub struct RerankOptions {
    /// API key placed into `Config::api_key` when the client is built.
    pub api_key: Option<String>,
    /// API base value placed into `Config::api_base` when the client is built.
    pub api_base: Option<String>,
    pub top_n: Option<u32>,
    pub max_tokens_per_doc: Option<u32>,
    pub user: Option<String>,
}
impl From<RerankOptions> for Config {
fn from(options: RerankOptions) -> Self {
Config {
api_key: options.api_key,
api_base: options.api_base,
extra: Default::default(),
}
}
}
pub async fn rerank(
model: &str,
query: &str,
documents: Vec<String>,
options: RerankOptions,
) -> Result<RerankResponse> {
let client = Otari::from_config(options.clone().into())?;
let params = RerankParams {
model_id: model.to_string(),
query: query.to_string(),
documents,
top_n: options.top_n,
max_tokens_per_doc: options.max_tokens_per_doc,
user: options.user,
};
client.rerank(params).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder stores each chained value in its field.
    #[test]
    fn test_completion_options_builder() {
        let options = CompletionOptions::with_api_key("test-key")
            .temperature(0.7)
            .max_tokens(100);
        assert_eq!(options.api_key, Some("test-key".to_string()));
        assert_eq!(options.temperature, Some(0.7));
        assert_eq!(options.max_tokens, Some(100));
        // Fields not touched by the builder stay unset.
        assert!(options.api_base.is_none());
        assert!(options.tools.is_none());
    }

    /// `Default` produces fully-unset options.
    #[test]
    fn test_completion_options_default_is_empty() {
        let options = CompletionOptions::default();
        assert!(options.api_key.is_none());
        assert!(options.api_base.is_none());
        assert!(options.temperature.is_none());
        assert!(options.max_tokens.is_none());
    }

    /// Only the connection fields are carried into `Config`.
    #[test]
    fn test_completion_options_into_config() {
        let options = CompletionOptions::with_api_key("test-key")
            .api_base("https://example.invalid/v1")
            .temperature(0.2);
        let config = Config::from(options);
        assert_eq!(config.api_key.as_deref(), Some("test-key"));
        assert_eq!(config.api_base.as_deref(), Some("https://example.invalid/v1"));
    }

    /// Rerank options map their connection fields into `Config` the same way.
    #[test]
    fn test_rerank_options_into_config() {
        let options = RerankOptions {
            api_key: Some("rerank-key".to_string()),
            top_n: Some(3),
            ..Default::default()
        };
        let config = Config::from(options);
        assert_eq!(config.api_key.as_deref(), Some("rerank-key"));
        assert!(config.api_base.is_none());
    }
}