use crate::client::{CompletionRequest, LlmClient, Message, Role};
use crate::error::Error;
use crate::schema;
/// Per-call settings accepted by [`complete_with`].
///
/// Every field is copied verbatim into the outgoing `CompletionRequest`.
/// [`Default`] yields a 4096-token limit with no system prompt and no model
/// override, so callers can use struct-update syntax to change one field:
/// `CompleteOptions { max_tokens: 512, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompleteOptions {
    /// Value placed in the request's `max_tokens` field.
    pub max_tokens: u32,
    /// Optional system prompt placed in the request's `system` field.
    pub system: Option<String>,
    /// Optional model name placed in the request's `model_override` field.
    pub model_override: Option<String>,
}
impl Default for CompleteOptions {
fn default() -> Self {
Self {
max_tokens: 4096,
system: None,
model_override: None,
}
}
}
pub async fn complete_with<T>(
client: &dyn LlmClient,
prompt: &str,
opts: CompleteOptions,
) -> Result<T, Error>
where
T: schemars::JsonSchema + serde::de::DeserializeOwned,
{
let raw_schema = serde_json::to_value(schemars::schema_for!(T))
.map_err(|e| Error::SchemaError(format!("schema_for serialization: {e}")))?;
let normalized = schema::for_structured_output(raw_schema);
let request = CompletionRequest {
system: opts.system,
messages: vec![Message {
role: Role::User,
content: prompt.to_string(),
tool_call_id: None,
}],
max_tokens: opts.max_tokens,
model_override: opts.model_override,
schema: Some(normalized),
tools: None,
tool_choice: None,
};
let text = client.complete(request).await?;
serde_json::from_str::<T>(&text).map_err(|e| Error::Deserialization(e.to_string()))
}
/// Convenience form of [`complete_with`] that uses `CompleteOptions::default()`.
///
/// # Errors
///
/// Returns whatever [`complete_with`] returns for the same client and prompt.
pub async fn complete<T>(client: &dyn LlmClient, prompt: &str) -> Result<T, Error>
where
    T: schemars::JsonSchema + serde::de::DeserializeOwned,
{
    let defaults = CompleteOptions::default();
    complete_with::<T>(client, prompt, defaults).await
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::Deserialize;

    use super::*;
    use crate::client::{CompletionRequest, TokenStream};

    /// Canned reply that deserializes into [`SimpleStruct`].
    const INT_PAYLOAD: &str = r#"{"value":1}"#;

    /// Target type with a single string field.
    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct MyOutput {
        value: String,
    }

    /// Target type with a single integer field.
    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct SimpleStruct {
        value: i64,
    }

    /// Stub client whose `complete` always returns the wrapped string.
    struct ConstClient(String);

    #[async_trait]
    impl LlmClient for ConstClient {
        fn default_model(&self) -> &str {
            "test"
        }

        async fn complete(&self, _: CompletionRequest) -> Result<String, Error> {
            let Self(canned) = self;
            Ok(canned.clone())
        }

        async fn complete_stream(&self, _: CompletionRequest) -> Result<TokenStream, Error> {
            Err(Error::Unsupported)
        }

        async fn embed(&self, _: &str) -> Result<Vec<f32>, Error> {
            Err(Error::Unsupported)
        }
    }

    /// Stub client that stores the most recent request passed to `complete`
    /// and answers with a fixed string.
    struct CapturingClient {
        response: String,
        captured: Mutex<Option<CompletionRequest>>,
    }

    impl CapturingClient {
        fn new(response: &str) -> Self {
            let response = response.to_owned();
            let captured = Mutex::new(None);
            Self { response, captured }
        }

        /// Removes and returns the stored request; panics if none was recorded.
        fn take_request(&self) -> CompletionRequest {
            self.captured.lock().unwrap().take().unwrap()
        }
    }

    #[async_trait]
    impl LlmClient for CapturingClient {
        fn default_model(&self) -> &str {
            "test"
        }

        async fn complete(&self, req: CompletionRequest) -> Result<String, Error> {
            {
                let mut slot = self.captured.lock().unwrap();
                *slot = Some(req);
            }
            Ok(self.response.clone())
        }

        async fn complete_stream(&self, _: CompletionRequest) -> Result<TokenStream, Error> {
            Err(Error::Unsupported)
        }

        async fn embed(&self, _: &str) -> Result<Vec<f32>, Error> {
            Err(Error::Unsupported)
        }
    }

    #[tokio::test]
    async fn complete_returns_typed_result() {
        let client = ConstClient(String::from(r#"{"value":"hello"}"#));
        let parsed: MyOutput = complete(&client, "test prompt").await.unwrap();
        assert_eq!(parsed.value, "hello");
    }

    #[tokio::test]
    async fn complete_propagates_deserialization_error() {
        let client = ConstClient(String::from(r#"{"wrong_field":"hello"}"#));
        match complete::<MyOutput>(&client, "test prompt").await {
            Err(Error::Deserialization(_)) => {}
            other => panic!("expected Deserialization error, got: {other:?}"),
        }
    }

    #[test]
    fn complete_options_default() {
        let CompleteOptions {
            max_tokens,
            system,
            model_override,
        } = CompleteOptions::default();
        assert_eq!(max_tokens, 4096);
        assert!(system.is_none());
        assert!(model_override.is_none());
    }

    #[tokio::test]
    async fn complete_with_uses_provided_max_tokens() {
        let client = CapturingClient::new(INT_PAYLOAD);
        let opts = CompleteOptions {
            max_tokens: 9999,
            system: Some(String::from("sys")),
            model_override: Some(String::from("m")),
        };
        let _: SimpleStruct = complete_with(&client, "p", opts).await.unwrap();

        let sent = client.take_request();
        assert_eq!(sent.max_tokens, 9999);
        assert_eq!(sent.system.as_deref(), Some("sys"));
        assert_eq!(sent.model_override.as_deref(), Some("m"));
        assert!(sent.schema.is_some());
    }

    #[tokio::test]
    async fn complete_delegates_to_complete_with() {
        let client = CapturingClient::new(INT_PAYLOAD);
        let _: SimpleStruct = complete(&client, "p").await.unwrap();

        let sent = client.take_request();
        assert_eq!(sent.max_tokens, 4096);
        assert!(sent.system.is_none());
        assert!(sent.model_override.is_none());
        assert!(sent.schema.is_some());
    }
}