artificial_openai/
backend.rs1use std::{any::Any, future::Future, pin::Pin, sync::Arc};
2
3use artificial_core::{
4 backend::Backend,
5 error::{ArtificialError, Result},
6 template::{IntoPrompt, PromptTemplate},
7};
8use schemars::{JsonSchema, SchemaGenerator, r#gen::SchemaSettings};
9use serde::Deserialize;
10use serde_json::json;
11
12use crate::{
13 OpenAiAdapter,
14 api_v1::chat_completion::{ChatCompletionMessage, ChatCompletionRequest, FinishReason},
15 error::OpenAiError,
16 model_map::map_model,
17};
18
19impl Backend for OpenAiAdapter {
35 type Message = ChatCompletionMessage;
37
38 fn chat_complete<P>(&self, prompt: P) -> Pin<Box<dyn Future<Output = Result<P::Output>> + Send>>
43 where
44 P: PromptTemplate + Send + 'static,
45 <P as IntoPrompt>::Message: Into<Self::Message>,
47 {
48 let client = Arc::clone(&self.client);
49
50 Box::pin(async move {
51 let response_format = derive_response_format::<P::Output>()?;
52
53 let messages = prompt.into_prompt().into_iter().map(Into::into).collect();
54
55 let model = map_model(P::MODEL).ok_or(ArtificialError::InvalidRequest(format!(
56 "backend does not support selected model: {:?}",
57 P::MODEL
58 )))?;
59
60 let request =
61 ChatCompletionRequest::new(model.into(), messages).response_format(response_format);
62
63 let response = client.chat_completion(request).await?;
64
65 let Some(first_choice) = response.choices.first() else {
66 return Err(OpenAiError::Format("response has no choices".into()).into());
67 };
68
69 match &first_choice.finish_reason {
70 None | Some(FinishReason::Stop) => {
71 let content =
72 first_choice
73 .message
74 .content
75 .as_ref()
76 .ok_or(OpenAiError::Format(
77 "invalid response: empty content".into(),
78 ))?;
79 parse_response(content)
80 }
81 Some(other) => Err(OpenAiError::Format(format!(
82 "unhandled finish reason on API: {other:?}"
83 ))
84 .into()),
85 }
86 })
87 }
88}
89
90fn parse_response<T>(content: &str) -> Result<T>
92where
93 T: for<'de> Deserialize<'de>,
94{
95 Ok(serde_json::from_str(content)?)
96}
97
98fn derive_response_format<T>() -> Result<serde_json::Value>
103where
104 T: JsonSchema + Any,
105{
106 let requested_type = std::any::TypeId::of::<T>();
107 let json_value_type = std::any::TypeId::of::<serde_json::Value>();
108
109 if requested_type.eq(&json_value_type) {
111 return Ok(json!({ "type": "json_object" }));
112 }
113
114 let schema_json = {
116 let mut settings = SchemaSettings::draft07();
117 settings.inline_subschemas = true;
118
119 let mut generator = SchemaGenerator::new(settings);
120 let root_schema = generator.root_schema_for::<T>();
121
122 serde_json::to_value(root_schema)?
123 };
124
125 let schema_title = schema_json
127 .as_object()
128 .and_then(|o| o.get("title"))
129 .and_then(|t| t.as_str())
130 .map(str::to_owned)
131 .ok_or(ArtificialError::InvalidRequest(
132 "json schema has no title".into(),
133 ))?;
134
135 Ok(json!({
136 "type": "json_schema",
137 "json_schema": {
138 "strict": true,
139 "name": schema_title,
140 "schema": schema_json,
141 }
142 }))
143}