Skip to main content

ferro_ai/
complete.rs

1//! Typed completion entry point for structured LLM output.
2//!
3//! [`complete`] is the primary surface of the ferro-ai SDK for structured-output use.
4//! Callers never import `schemars` or `serde_json` directly — schema generation,
5//! normalization, and JSON parsing are fully encapsulated (SC#1, D-01).
6//!
7//! ## Usage
8//!
9//! ```rust,ignore
10//! use ferro_ai::{complete, AnthropicClient};
11//! use serde::Deserialize;
12//! use schemars::JsonSchema;
13//!
14//! #[derive(Deserialize, JsonSchema)]
15//! struct OrderSummary { name: String, total: f64 }
16//!
17//! let client = AnthropicClient::from_env().unwrap();
18//! let summary: OrderSummary = complete(&client, "Summarize order #42 as JSON").await?;
19//! ```
20//!
21//! ## Internal flow
22//!
//! 1. `schemars::schema_for!(T)` — generate Draft 2020-12 schema from the Rust type.
24//! 2. `schema::for_structured_output(raw)` — normalize for Anthropic/OpenAI constraints;
25//!    activates the ServiceDef-aware projection-enum closing path when T contains
26//!    ferro-projections types in its `$defs` (D-07).
27//! 3. Build `CompletionRequest` with `schema: Some(normalized)`.
28//! 4. `client.complete(request)` — delegate to the configured LLM provider.
29//! 5. `serde_json::from_str::<T>(&text)` — deserialize the JSON response into T.
30//!
//! ## Plan 04 dependency note
//!
//! After Plan 02 (Phase 165) `CompletionRequest` had exactly five fields: `system`,
//! `messages`, `max_tokens`, `model_override`, `schema`. Plan 04 adds
//! `tools: Option<Vec<ToolRequest>>` and `tool_choice: Option<ToolChoice>`; the struct
//! literal in this file already lists both and sets them to `None`, because typed
//! completion does not request tool use. If `CompletionRequest` later derives
//! `Default`, the literal can be restructured around `..Default::default()`.
39
40use crate::client::{CompletionRequest, LlmClient, Message, Role};
41use crate::error::Error;
42use crate::schema;
43
/// Options controlling a typed completion request.
///
/// Every field is forwarded unchanged into the outgoing `CompletionRequest`:
///
/// - `max_tokens` caps the response; callers map `FERRO_AI_MAX_TOKENS_PER_COMMAND` onto it.
/// - `system` supplies an optional system prompt for context-heavy completions.
/// - `model_override` selects a non-default model for this request only.
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so callers can log an
/// options value, reuse one across several requests, and compare values in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompleteOptions {
    /// Maximum number of tokens in the completion response.
    pub max_tokens: u32,
    /// Optional system prompt prepended before the user message.
    pub system: Option<String>,
    /// Override the provider's default model for this request.
    pub model_override: Option<String>,
}
57
58impl Default for CompleteOptions {
59    fn default() -> Self {
60        Self {
61            max_tokens: 4096,
62            system: None,
63            model_override: None,
64        }
65    }
66}
67
68/// Typed completion with explicit options. Same ServiceDef-aware schema-normalizer path as
69/// [`complete`], parameterized by [`CompleteOptions`].
70///
71/// Callers never touch `schemars` or `serde_json` directly (SC#1).
72///
73/// # Errors
74///
75/// - `Error::Provider` — the LLM provider returned a non-success HTTP response.
76/// - `Error::Deserialization` — the provider response was not valid JSON for `T`.
77/// - `Error::Unsupported` — the client does not support non-streaming completions.
78/// - `Error::SchemaError` — the type's schema could not be serialized.
79pub async fn complete_with<T>(
80    client: &dyn LlmClient,
81    prompt: &str,
82    opts: CompleteOptions,
83) -> Result<T, Error>
84where
85    T: schemars::JsonSchema + serde::de::DeserializeOwned,
86{
87    let raw_schema = serde_json::to_value(schemars::schema_for!(T))
88        .map_err(|e| Error::SchemaError(format!("schema_for serialization: {e}")))?;
89    let normalized = schema::for_structured_output(raw_schema);
90
91    let request = CompletionRequest {
92        system: opts.system,
93        messages: vec![Message {
94            role: Role::User,
95            content: prompt.to_string(),
96            tool_call_id: None,
97        }],
98        max_tokens: opts.max_tokens,
99        model_override: opts.model_override,
100        schema: Some(normalized),
101        tools: None,
102        tool_choice: None,
103    };
104
105    let text = client.complete(request).await?;
106    serde_json::from_str::<T>(&text).map_err(|e| Error::Deserialization(e.to_string()))
107}
108
109/// Typed completion: generate a structured `T` from a prompt.
110///
111/// Delegates to [`complete_with`] with [`CompleteOptions::default`].
112/// Internally calls `schemars::schema_for::<T>()`, normalizes the schema via
113/// `schema::for_structured_output`, builds a `CompletionRequest` with the normalized
114/// schema, calls `client.complete`, and deserializes the JSON response into `T`.
115///
116/// Callers never touch `schemars` or `serde_json` directly (SC#1).
117///
118/// # Errors
119///
120/// - `Error::Provider` — the LLM provider returned a non-success HTTP response.
121/// - `Error::Deserialization` — the provider response was not valid JSON for `T`.
122/// - `Error::Unsupported` — the client does not support non-streaming completions.
123pub async fn complete<T>(client: &dyn LlmClient, prompt: &str) -> Result<T, Error>
124where
125    T: schemars::JsonSchema + serde::de::DeserializeOwned,
126{
127    complete_with(client, prompt, CompleteOptions::default()).await
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use std::sync::Mutex;

    use crate::client::{CompletionRequest, TokenStream};

    /// Output type with a single string field; used by the round-trip and error tests.
    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct MyOutput {
        value: String,
    }

    /// Output type with a single integer field; used by the request-capture tests.
    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct SimpleStruct {
        value: i64,
    }

    /// Stub client whose `complete` always yields a copy of the wrapped string.
    struct ConstClient(String);

    #[async_trait]
    impl LlmClient for ConstClient {
        fn default_model(&self) -> &str {
            "test"
        }

        async fn complete(&self, _request: CompletionRequest) -> Result<String, Error> {
            Ok(self.0.to_owned())
        }

        async fn complete_stream(&self, _request: CompletionRequest) -> Result<TokenStream, Error> {
            Err(Error::Unsupported)
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, Error> {
            Err(Error::Unsupported)
        }
    }

    /// Stub client that records the most recent `CompletionRequest` so tests can
    /// inspect what was sent, and answers with a fixed response string.
    struct CapturingClient {
        response: String,
        captured: Mutex<Option<CompletionRequest>>,
    }

    impl CapturingClient {
        /// Builds a client that replies with `response` and has nothing captured yet.
        fn new(response: &str) -> Self {
            let captured = Mutex::new(None);
            let response = response.to_owned();
            Self { response, captured }
        }

        /// Removes and returns the recorded request; panics if none was recorded.
        fn take_request(&self) -> CompletionRequest {
            self.captured.lock().unwrap().take().unwrap()
        }
    }

    #[async_trait]
    impl LlmClient for CapturingClient {
        fn default_model(&self) -> &str {
            "test"
        }

        async fn complete(&self, req: CompletionRequest) -> Result<String, Error> {
            {
                // Keep the guard in its own scope so it is released before replying.
                let mut slot = self.captured.lock().unwrap();
                *slot = Some(req);
            }
            Ok(self.response.clone())
        }

        async fn complete_stream(&self, _request: CompletionRequest) -> Result<TokenStream, Error> {
            Err(Error::Unsupported)
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, Error> {
            Err(Error::Unsupported)
        }
    }

    /// SC#1: `complete::<T>()` round-trips a typed value through a stub client.
    ///
    /// Only `complete`, the client trait, and the output type are involved at the call
    /// site; the stub returns a fixed JSON string which is deserialized into the struct.
    #[tokio::test]
    async fn complete_returns_typed_result() {
        let client = ConstClient(String::from(r#"{"value":"hello"}"#));
        let parsed: MyOutput = complete(&client, "test prompt").await.unwrap();
        assert_eq!(parsed.value, "hello");
    }

    /// A response that does not fit `T` surfaces as `Error::Deserialization`.
    #[tokio::test]
    async fn complete_propagates_deserialization_error() {
        // `MyOutput` requires a `value` field, which this payload lacks.
        let client = ConstClient(String::from(r#"{"wrong_field":"hello"}"#));
        let other = complete::<MyOutput>(&client, "test prompt").await;
        if !matches!(&other, Err(Error::Deserialization(_))) {
            panic!("expected Deserialization error, got: {other:?}");
        }
    }

    /// `CompleteOptions::default()` yields the canonical zero-config values.
    #[test]
    fn complete_options_default() {
        let defaults = CompleteOptions::default();
        assert_eq!(defaults.max_tokens, 4096);
        assert!(defaults.system.is_none());
        assert!(defaults.model_override.is_none());
    }

    /// `complete_with::<T>()` copies each option onto the outgoing request.
    #[tokio::test]
    async fn complete_with_uses_provided_max_tokens() {
        let client = CapturingClient::new(r#"{"value":1}"#);
        let opts = CompleteOptions {
            max_tokens: 9999,
            system: Some(String::from("sys")),
            model_override: Some(String::from("m")),
        };
        let _: SimpleStruct = complete_with(&client, "p", opts).await.unwrap();

        let sent = client.take_request();
        assert_eq!(sent.max_tokens, 9999);
        assert_eq!(sent.system.as_deref(), Some("sys"));
        assert_eq!(sent.model_override.as_deref(), Some("m"));
        assert!(sent.schema.is_some());
    }

    /// `complete::<T>()` is a thin delegate: the request carries the default options.
    #[tokio::test]
    async fn complete_delegates_to_complete_with() {
        let client = CapturingClient::new(r#"{"value":1}"#);
        let _: SimpleStruct = complete(&client, "p").await.unwrap();

        let sent = client.take_request();
        assert_eq!(sent.max_tokens, 4096);
        assert!(sent.system.is_none());
        assert!(sent.model_override.is_none());
        assert!(sent.schema.is_some());
    }
}