ferro_ai/complete.rs
1//! Typed completion entry point for structured LLM output.
2//!
3//! [`complete`] is the primary surface of the ferro-ai SDK for structured-output use.
//! Beyond deriving `Deserialize` and `JsonSchema` on their output type, callers never call
//! into `schemars` or `serde_json` — schema generation, normalization, and JSON parsing are
//! fully encapsulated (SC#1, D-01).
6//!
7//! ## Usage
8//!
9//! ```rust,ignore
10//! use ferro_ai::{complete, AnthropicClient};
11//! use serde::Deserialize;
12//! use schemars::JsonSchema;
13//!
14//! #[derive(Deserialize, JsonSchema)]
15//! struct OrderSummary { name: String, total: f64 }
16//!
17//! let client = AnthropicClient::from_env().unwrap();
18//! let summary: OrderSummary = complete(&client, "Summarize order #42 as JSON").await?;
19//! ```
20//!
21//! ## Internal flow
22//!
//! 1. `schemars::schema_for!(T)` — generate a Draft 2020-12 schema from the Rust type and
//!    serialize it to a `serde_json::Value`.
24//! 2. `schema::for_structured_output(raw)` — normalize for Anthropic/OpenAI constraints;
25//! activates the ServiceDef-aware projection-enum closing path when T contains
26//! ferro-projections types in its `$defs` (D-07).
27//! 3. Build `CompletionRequest` with `schema: Some(normalized)`.
28//! 4. `client.complete(request)` — delegate to the configured LLM provider.
29//! 5. `serde_json::from_str::<T>(&text)` — deserialize the JSON response into T.
30//!
//! ## Plan 04 dependency note
//!
//! The `CompletionRequest` struct literal in [`complete_with`] lists every field
//! explicitly. Five fields existed after Plan 02 (Phase 165): `system`, `messages`,
//! `max_tokens`, `model_override`, `schema`. Two more were added by Plan 04:
//! `tools: Option<Vec<ToolRequest>>` and `tool_choice: Option<ToolChoice>`. The literal
//! sets both Plan 04 fields to `None`, because this module never attaches tools to a
//! structured-output request. Any field added to `CompletionRequest` later must also be
//! added to this literal. Alternatively, restructure the literal via
//! `..Default::default()` if a `Default` derive is added.
38//! restructuring via `Default` if that derive is added).
39
40use crate::client::{CompletionRequest, LlmClient, Message, Role};
41use crate::error::Error;
42use crate::schema;
43
44/// Options controlling a typed completion request.
45///
46/// `max_tokens` caps the response; callers map `FERRO_AI_MAX_TOKENS_PER_COMMAND` onto it.
47/// `system` supplies an optional system prompt for context-heavy completions.
48/// `model_override` selects a non-default model for this request only.
49pub struct CompleteOptions {
50 /// Maximum number of tokens in the completion response.
51 pub max_tokens: u32,
52 /// Optional system prompt prepended before the user message.
53 pub system: Option<String>,
54 /// Override the provider's default model for this request.
55 pub model_override: Option<String>,
56}
57
58impl Default for CompleteOptions {
59 fn default() -> Self {
60 Self {
61 max_tokens: 4096,
62 system: None,
63 model_override: None,
64 }
65 }
66}
67
68/// Typed completion with explicit options. Same ServiceDef-aware schema-normalizer path as
69/// [`complete`], parameterized by [`CompleteOptions`].
70///
71/// Callers never touch `schemars` or `serde_json` directly (SC#1).
72///
73/// # Errors
74///
75/// - `Error::Provider` — the LLM provider returned a non-success HTTP response.
76/// - `Error::Deserialization` — the provider response was not valid JSON for `T`.
77/// - `Error::Unsupported` — the client does not support non-streaming completions.
78/// - `Error::SchemaError` — the type's schema could not be serialized.
79pub async fn complete_with<T>(
80 client: &dyn LlmClient,
81 prompt: &str,
82 opts: CompleteOptions,
83) -> Result<T, Error>
84where
85 T: schemars::JsonSchema + serde::de::DeserializeOwned,
86{
87 let raw_schema = serde_json::to_value(schemars::schema_for!(T))
88 .map_err(|e| Error::SchemaError(format!("schema_for serialization: {e}")))?;
89 let normalized = schema::for_structured_output(raw_schema);
90
91 let request = CompletionRequest {
92 system: opts.system,
93 messages: vec![Message {
94 role: Role::User,
95 content: prompt.to_string(),
96 tool_call_id: None,
97 }],
98 max_tokens: opts.max_tokens,
99 model_override: opts.model_override,
100 schema: Some(normalized),
101 tools: None,
102 tool_choice: None,
103 };
104
105 let text = client.complete(request).await?;
106 serde_json::from_str::<T>(&text).map_err(|e| Error::Deserialization(e.to_string()))
107}
108
109/// Typed completion: generate a structured `T` from a prompt.
110///
111/// Delegates to [`complete_with`] with [`CompleteOptions::default`].
112/// Internally calls `schemars::schema_for::<T>()`, normalizes the schema via
113/// `schema::for_structured_output`, builds a `CompletionRequest` with the normalized
114/// schema, calls `client.complete`, and deserializes the JSON response into `T`.
115///
116/// Callers never touch `schemars` or `serde_json` directly (SC#1).
117///
118/// # Errors
119///
120/// - `Error::Provider` — the LLM provider returned a non-success HTTP response.
121/// - `Error::Deserialization` — the provider response was not valid JSON for `T`.
122/// - `Error::Unsupported` — the client does not support non-streaming completions.
123pub async fn complete<T>(client: &dyn LlmClient, prompt: &str) -> Result<T, Error>
124where
125 T: schemars::JsonSchema + serde::de::DeserializeOwned,
126{
127 complete_with(client, prompt, CompleteOptions::default()).await
128}
129
130#[cfg(test)]
131mod tests {
132 use super::*;
133 use async_trait::async_trait;
134 use schemars::JsonSchema;
135 use serde::Deserialize;
136 use std::sync::Mutex;
137
138 use crate::client::{CompletionRequest, TokenStream};
139
140 #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
141 struct MyOutput {
142 value: String,
143 }
144
145 /// Minimal struct for complete_with / delegation tests.
146 #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
147 struct SimpleStruct {
148 value: i64,
149 }
150
151 /// Mock LLM client that always returns the same fixed JSON string.
152 struct ConstClient(String);
153
154 #[async_trait]
155 impl LlmClient for ConstClient {
156 fn default_model(&self) -> &str {
157 "test"
158 }
159
160 async fn complete(&self, _: CompletionRequest) -> Result<String, Error> {
161 Ok(self.0.clone())
162 }
163
164 async fn complete_stream(&self, _: CompletionRequest) -> Result<TokenStream, Error> {
165 Err(Error::Unsupported)
166 }
167
168 async fn embed(&self, _: &str) -> Result<Vec<f32>, Error> {
169 Err(Error::Unsupported)
170 }
171 }
172
173 /// Mock LLM client that captures the last CompletionRequest for assertion.
174 struct CapturingClient {
175 response: String,
176 captured: Mutex<Option<CompletionRequest>>,
177 }
178
179 impl CapturingClient {
180 fn new(response: &str) -> Self {
181 Self {
182 response: response.to_string(),
183 captured: Mutex::new(None),
184 }
185 }
186 }
187
188 #[async_trait]
189 impl LlmClient for CapturingClient {
190 fn default_model(&self) -> &str {
191 "test"
192 }
193
194 async fn complete(&self, req: CompletionRequest) -> Result<String, Error> {
195 *self.captured.lock().unwrap() = Some(req);
196 Ok(self.response.clone())
197 }
198
199 async fn complete_stream(&self, _: CompletionRequest) -> Result<TokenStream, Error> {
200 Err(Error::Unsupported)
201 }
202
203 async fn embed(&self, _: &str) -> Result<Vec<f32>, Error> {
204 Err(Error::Unsupported)
205 }
206 }
207
208 /// SC#1: `complete::<T>()` round-trips a typed value via a mock client.
209 ///
210 /// The caller never imports schemars or serde_json — only `complete`, the client
211 /// trait, and the output type are needed. The mock returns a fixed JSON string
212 /// and the function deserializes it into the typed struct.
213 #[tokio::test]
214 async fn complete_returns_typed_result() {
215 let client = ConstClient(r#"{"value":"hello"}"#.to_string());
216 let result = complete::<MyOutput>(&client, "test prompt").await.unwrap();
217 assert_eq!(result.value, "hello");
218 }
219
220 /// Deserialization errors are reported as `Error::Deserialization`.
221 #[tokio::test]
222 async fn complete_propagates_deserialization_error() {
223 let client = ConstClient(r#"{"wrong_field":"hello"}"#.to_string());
224 let result = complete::<MyOutput>(&client, "test prompt").await;
225 // MyOutput has a required `value` field; missing it causes a deserialization error.
226 // The error type should not be Unsupported or Provider.
227 match result {
228 Err(Error::Deserialization(_)) => {}
229 other => panic!("expected Deserialization error, got: {other:?}"),
230 }
231 }
232
233 /// `CompleteOptions::default()` produces the canonical zero-config values.
234 #[test]
235 fn complete_options_default() {
236 let opts = CompleteOptions::default();
237 assert_eq!(opts.max_tokens, 4096);
238 assert!(opts.system.is_none());
239 assert!(opts.model_override.is_none());
240 }
241
242 /// `complete_with::<T>()` forwards options to the CompletionRequest fields.
243 #[tokio::test]
244 async fn complete_with_uses_provided_max_tokens() {
245 let client = CapturingClient::new(r#"{"value":1}"#);
246 let opts = CompleteOptions {
247 max_tokens: 9999,
248 system: Some("sys".to_string()),
249 model_override: Some("m".to_string()),
250 };
251 let _: SimpleStruct = complete_with(&client, "p", opts).await.unwrap();
252 let req = client.captured.lock().unwrap().take().unwrap();
253 assert_eq!(req.max_tokens, 9999);
254 assert_eq!(req.system, Some("sys".to_string()));
255 assert_eq!(req.model_override, Some("m".to_string()));
256 assert!(req.schema.is_some());
257 }
258
259 /// `complete::<T>()` is a thin delegate: it passes `CompleteOptions::default()` values.
260 #[tokio::test]
261 async fn complete_delegates_to_complete_with() {
262 let client = CapturingClient::new(r#"{"value":1}"#);
263 let _: SimpleStruct = complete(&client, "p").await.unwrap();
264 let req = client.captured.lock().unwrap().take().unwrap();
265 assert_eq!(req.max_tokens, 4096);
266 assert!(req.system.is_none());
267 assert!(req.model_override.is_none());
268 assert!(req.schema.is_some());
269 }
270}