artificial_openai/
provider_impl_prompt.rs

1use std::{any::Any, future::Future, pin::Pin, sync::Arc};
2
3use artificial_core::{
4    error::{ArtificialError, Result},
5    generic::{GenericChatCompletionResponse, GenericUsageReport, ResponseContent},
6    provider::PromptExecutionProvider,
7    template::{IntoPrompt, PromptTemplate},
8};
9use schemars::{JsonSchema, SchemaGenerator, r#gen::SchemaSettings};
10use serde_json::json;
11
12use crate::{
13    OpenAiAdapter,
14    api_v1::{ChatCompletionMessage, ChatCompletionRequest, FinishReason},
15    error::OpenAiError,
16    model_map::map_model,
17};
18
19/// Implementation of [`ChatCompletionProvider`] for the [`OpenAiAdapter`].
20///
21/// The type is only a thin glue layer—almost all heavy lifting is done by the
22/// inner `OpenAiClient` (HTTP) and the adapter’s config (API key, base URL, …).
23///
24/// Responsibilities:
25///
26/// 1. **Convert** the generic prompt into OpenAI‐compatible chat messages.
27/// 2. **Enrich** the request with a JSON Schema derived from `Prompt::Output`.
28/// 3. **Call** the `/v1/chat/completions` endpoint and bubble up transport errors.
29/// 4. **Validate & deserialize** the returned JSON into `Prompt::Output`.
30///
31/// The implementation purposefully rejects any *streaming* or *multi-choice*
32/// responses for now; this keeps the surface minimal and makes error handling
33/// easier to reason about.
34impl PromptExecutionProvider for OpenAiAdapter {
35    /// Provider-specific chat message type.
36    type Message = ChatCompletionMessage;
37
38    /// Perform a non-streaming chat completion and deserialize the result.
39    ///
40    /// The method is object-safe by returning a boxed `Future` rather than using
41    /// async/await syntax directly.
42    fn prompt_execute<'a, 'p, P>(
43        &'a self,
44        prompt: P,
45    ) -> Pin<Box<dyn Future<Output = Result<GenericChatCompletionResponse<P::Output>>> + Send + 'p>>
46    where
47        'a: 'p,
48        P: PromptTemplate + Send + Sync + 'p,
49        <P as IntoPrompt>::Message: Into<Self::Message>,
50    {
51        let client = Arc::clone(&self.client);
52
53        let messages = prompt.into_prompt().into_iter().map(Into::into).collect();
54
55        Box::pin(async move {
56            let response_format = derive_response_format::<P::Output>()?;
57
58            let model = map_model(&P::MODEL).ok_or(ArtificialError::InvalidRequest(format!(
59                "backend does not support selected model: {:?}",
60                P::MODEL
61            )))?;
62
63            let request =
64                ChatCompletionRequest::new(model.into(), messages).response_format(response_format);
65
66            let response = client.chat_completion(request).await?;
67
68            let usage_report = GenericUsageReport {
69                prompt_tokens: response.usage.prompt_tokens as i64,
70                completion_tokens: response.usage.completion_tokens as i64,
71                total_tokens: response.usage.total_tokens as i64,
72            };
73
74            let Some(first_choice) = response.choices.first() else {
75                return Err(OpenAiError::Format("response has no choices".into()).into());
76            };
77
78            match &first_choice.finish_reason {
79                None | Some(FinishReason::Stop) => {
80                    let content =
81                        first_choice
82                            .message
83                            .content
84                            .as_ref()
85                            .ok_or(OpenAiError::Format(
86                                "invalid response: empty content".into(),
87                            ))?;
88                    let content = serde_json::from_str(content.as_str())?;
89                    let response = GenericChatCompletionResponse {
90                        content: ResponseContent::Finished(content),
91                        usage: Some(usage_report),
92                    };
93                    Ok(response)
94                }
95                Some(other) => Err(OpenAiError::Format(format!(
96                    "unhandled finish reason on API: {other:?}"
97                ))
98                .into()),
99            }
100        })
101    }
102}
103
104/// Produce the `response_format` object expected by OpenAI.
105///
106/// * If `T == serde_json::Value` we ask for an *unstructured* JSON blob.
107/// * Otherwise we inline a full JSON Schema generated by `schemars`.
108fn derive_response_format<T>() -> Result<serde_json::Value>
109where
110    T: JsonSchema + Any,
111{
112    let requested_type = std::any::TypeId::of::<T>();
113    let json_value_type = std::any::TypeId::of::<serde_json::Value>();
114
115    // Fast-path: caller wants raw JSON.
116    if requested_type.eq(&json_value_type) {
117        return Ok(json!({ "type": "json_object" }));
118    }
119
120    // Generate inline schema (no $ref) for strict validation.
121    let schema_json = {
122        let mut settings = SchemaSettings::draft07();
123        settings.inline_subschemas = true;
124
125        let mut generator = SchemaGenerator::new(settings);
126        let root_schema = generator.root_schema_for::<T>();
127
128        serde_json::to_value(root_schema)?
129    };
130
131    // Extract a human-readable title for the schema.
132    let schema_title = schema_json
133        .as_object()
134        .and_then(|o| o.get("title"))
135        .and_then(|t| t.as_str())
136        .map(str::to_owned)
137        .ok_or(ArtificialError::InvalidRequest(
138            "json schema has no title".into(),
139        ))?;
140
141    Ok(json!({
142        "type": "json_schema",
143        "json_schema": {
144            "strict": true,
145            "name": schema_title,
146            "schema": schema_json,
147        }
148    }))
149}