artificial_openai/
provider_impl_prompt.rs1use std::{any::Any, future::Future, pin::Pin, sync::Arc};
2
3use artificial_core::{
4 error::{ArtificialError, Result},
5 generic::{GenericChatCompletionResponse, GenericUsageReport, ResponseContent},
6 provider::PromptExecutionProvider,
7 template::{IntoPrompt, PromptTemplate},
8};
9use schemars::{JsonSchema, SchemaGenerator, r#gen::SchemaSettings};
10use serde_json::json;
11
12use crate::{
13 OpenAiAdapter,
14 api_v1::{ChatCompletionMessage, ChatCompletionRequest, FinishReason},
15 error::OpenAiError,
16 model_map::map_model,
17};
18
19impl PromptExecutionProvider for OpenAiAdapter {
35 type Message = ChatCompletionMessage;
37
38 fn prompt_execute<'a, 'p, P>(
43 &'a self,
44 prompt: P,
45 ) -> Pin<Box<dyn Future<Output = Result<GenericChatCompletionResponse<P::Output>>> + Send + 'p>>
46 where
47 'a: 'p,
48 P: PromptTemplate + Send + Sync + 'p,
49 <P as IntoPrompt>::Message: Into<Self::Message>,
50 {
51 let client = Arc::clone(&self.client);
52
53 let messages = prompt.into_prompt().into_iter().map(Into::into).collect();
54
55 Box::pin(async move {
56 let response_format = derive_response_format::<P::Output>()?;
57
58 let model = map_model(&P::MODEL).ok_or(ArtificialError::InvalidRequest(format!(
59 "backend does not support selected model: {:?}",
60 P::MODEL
61 )))?;
62
63 let request =
64 ChatCompletionRequest::new(model.into(), messages).response_format(response_format);
65
66 let response = client.chat_completion(request).await?;
67
68 let usage_report = GenericUsageReport {
69 prompt_tokens: response.usage.prompt_tokens as i64,
70 completion_tokens: response.usage.completion_tokens as i64,
71 total_tokens: response.usage.total_tokens as i64,
72 };
73
74 let Some(first_choice) = response.choices.first() else {
75 return Err(OpenAiError::Format("response has no choices".into()).into());
76 };
77
78 match &first_choice.finish_reason {
79 None | Some(FinishReason::Stop) => {
80 let content =
81 first_choice
82 .message
83 .content
84 .as_ref()
85 .ok_or(OpenAiError::Format(
86 "invalid response: empty content".into(),
87 ))?;
88 let content = serde_json::from_str(content.as_str())?;
89 let response = GenericChatCompletionResponse {
90 content: ResponseContent::Finished(content),
91 usage: Some(usage_report),
92 };
93 Ok(response)
94 }
95 Some(other) => Err(OpenAiError::Format(format!(
96 "unhandled finish reason on API: {other:?}"
97 ))
98 .into()),
99 }
100 })
101 }
102}
103
104fn derive_response_format<T>() -> Result<serde_json::Value>
109where
110 T: JsonSchema + Any,
111{
112 let requested_type = std::any::TypeId::of::<T>();
113 let json_value_type = std::any::TypeId::of::<serde_json::Value>();
114
115 if requested_type.eq(&json_value_type) {
117 return Ok(json!({ "type": "json_object" }));
118 }
119
120 let schema_json = {
122 let mut settings = SchemaSettings::draft07();
123 settings.inline_subschemas = true;
124
125 let mut generator = SchemaGenerator::new(settings);
126 let root_schema = generator.root_schema_for::<T>();
127
128 serde_json::to_value(root_schema)?
129 };
130
131 let schema_title = schema_json
133 .as_object()
134 .and_then(|o| o.get("title"))
135 .and_then(|t| t.as_str())
136 .map(str::to_owned)
137 .ok_or(ArtificialError::InvalidRequest(
138 "json schema has no title".into(),
139 ))?;
140
141 Ok(json!({
142 "type": "json_schema",
143 "json_schema": {
144 "strict": true,
145 "name": schema_title,
146 "schema": schema_json,
147 }
148 }))
149}