artificial_openai/backend.rs

use std::{any::Any, future::Future, pin::Pin, sync::Arc};

use artificial_core::{
    backend::Backend,
    error::{ArtificialError, Result},
    template::{IntoPrompt, PromptTemplate},
};
use schemars::{JsonSchema, SchemaGenerator, r#gen::SchemaSettings};
use serde::Deserialize;
use serde_json::json;

use crate::{
    OpenAiAdapter,
    api_v1::chat_completion::{ChatCompletionMessage, ChatCompletionRequest, FinishReason},
    error::OpenAiError,
    model_map::map_model,
};

/// Implementation of [`Backend`] for the [`OpenAiAdapter`].
///
/// The type is only a thin glue layer; almost all heavy lifting is done by the
/// inner `OpenAiClient` (HTTP) and the adapter’s config (API key, base URL, …).
///
/// Responsibilities:
///
/// 1. **Convert** the generic prompt into OpenAI-compatible chat messages.
/// 2. **Enrich** the request with a JSON Schema derived from `Prompt::Output`.
/// 3. **Call** the `/v1/chat/completions` endpoint and bubble up transport errors.
/// 4. **Validate & deserialize** the returned JSON into `Prompt::Output`.
///
/// The implementation purposefully performs no *streaming* and inspects only the
/// *first* returned choice for now; this keeps the surface minimal and makes
/// error handling easier to reason about.
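///
/// A minimal sketch of the expected call shape; the `WeatherPrompt`/`Forecast`
/// names are hypothetical placeholders for a [`PromptTemplate`] implementor and
/// its `Output` type:
///
/// ```ignore
/// let forecast: Forecast = adapter.chat_complete(WeatherPrompt { city }).await?;
/// ```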
impl Backend for OpenAiAdapter {
    /// Provider-specific chat message type.
    type Message = ChatCompletionMessage;

    /// Perform a non-streaming chat completion and deserialize the result.
    ///
    /// The method returns a boxed `Future` rather than using async/await syntax
    /// directly, so the trait does not depend on `async fn` in traits.
    fn chat_complete<P>(&self, prompt: P) -> Pin<Box<dyn Future<Output = Result<P::Output>> + Send>>
    where
        P: PromptTemplate + Send + 'static,
        // Compile-time check: the prompt must emit messages we can convert.
        <P as IntoPrompt>::Message: Into<Self::Message>,
    {
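        // Clone the shared client handle up front: the returned future must be
        // `'static + Send`, so it cannot borrow `self` across await points.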
        let client = Arc::clone(&self.client);

        Box::pin(async move {
            let response_format = derive_response_format::<P::Output>()?;

            let messages = prompt.into_prompt().into_iter().map(Into::into).collect();

            let model = map_model(P::MODEL).ok_or_else(|| {
                ArtificialError::InvalidRequest(format!(
                    "backend does not support selected model: {:?}",
                    P::MODEL
                ))
            })?;

            let request =
                ChatCompletionRequest::new(model.into(), messages).response_format(response_format);

            let response = client.chat_completion(request).await?;

            let Some(first_choice) = response.choices.first() else {
                return Err(OpenAiError::Format("response has no choices".into()).into());
            };
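            // A missing finish reason or a clean `Stop` counts as a complete
            // answer; every other reason (e.g. truncation) becomes a format error.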
            match &first_choice.finish_reason {
                None | Some(FinishReason::Stop) => {
                    let content = first_choice.message.content.as_ref().ok_or_else(|| {
                        OpenAiError::Format("invalid response: empty content".into())
                    })?;
                    parse_response(content)
                }
                Some(other) => Err(OpenAiError::Format(format!(
                    "unhandled finish reason from API: {other:?}"
                ))
                .into()),
            }
        })
    }
}

/// Deserialize the provider payload into the caller’s expected type.
fn parse_response<T>(content: &str) -> Result<T>
where
    T: for<'de> Deserialize<'de>,
{
    Ok(serde_json::from_str(content)?)
}

/// Produce the `response_format` object expected by OpenAI.
///
/// * If `T == serde_json::Value` we ask for an *unstructured* JSON blob.
/// * Otherwise we inline a full JSON Schema generated by `schemars`.
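///
/// Roughly, the two shapes produced (the `"Forecast"` name is illustrative; it
/// is taken from the generated schema’s `title`):
///
/// ```text
/// { "type": "json_object" }
/// { "type": "json_schema",
///   "json_schema": { "strict": true, "name": "Forecast", "schema": { … } } }
/// ```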
fn derive_response_format<T>() -> Result<serde_json::Value>
where
    T: JsonSchema + Any,
{
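    // The `Any` bound implies `T: 'static`, which `TypeId::of::<T>()` requires
    // for the runtime type check below.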
    let requested_type = std::any::TypeId::of::<T>();
    let json_value_type = std::any::TypeId::of::<serde_json::Value>();

    // Fast-path: caller wants raw JSON.
    if requested_type == json_value_type {
        return Ok(json!({ "type": "json_object" }));
    }

    // Generate inline schema (no $ref) for strict validation.
    let schema_json = {
        let mut settings = SchemaSettings::draft07();
        settings.inline_subschemas = true;

        let mut generator = SchemaGenerator::new(settings);
        let root_schema = generator.root_schema_for::<T>();

        serde_json::to_value(root_schema)?
    };

    // Extract a human-readable title for the schema.
    let schema_title = schema_json
        .as_object()
        .and_then(|o| o.get("title"))
        .and_then(|t| t.as_str())
        .map(str::to_owned)
        .ok_or_else(|| ArtificialError::InvalidRequest("json schema has no title".into()))?;

    Ok(json!({
        "type": "json_schema",
        "json_schema": {
            "strict": true,
            "name": schema_title,
            "schema": schema_json,
        }
    }))
}