use std::{any::Any, future::Future, pin::Pin, sync::Arc};
use artificial_core::{
error::{ArtificialError, Result},
generic::{GenericChatCompletionResponse, GenericUsageReport, ResponseContent},
provider::PromptExecutionProvider,
template::{IntoPrompt, PromptTemplate},
};
use schemars::{JsonSchema, SchemaGenerator, r#gen::SchemaSettings};
use serde_json::json;
use crate::{
OpenAiAdapter,
api_v1::{ChatCompletionMessage, ChatCompletionRequest, FinishReason},
error::OpenAiError,
model_map::map_model,
};
impl PromptExecutionProvider for OpenAiAdapter {
    /// Wire message type; prompt messages must be convertible into it (see the
    /// `Into<Self::Message>` bound on [`Self::prompt_execute`]).
    type Message = ChatCompletionMessage;

    /// Sends `prompt` to the chat-completions endpoint and decodes the first choice's
    /// content as JSON into `P::Output`.
    ///
    /// The prompt is converted into [`ChatCompletionMessage`]s eagerly, before the returned
    /// future is polled. The request itself is only issued when the future runs.
    ///
    /// # Errors
    ///
    /// - Any error from `derive_response_format::<P::Output>()`.
    /// - [`ArtificialError::InvalidRequest`] if `P::MODEL` has no mapping for this backend.
    /// - Any error returned by the client's `chat_completion` call.
    /// - An error built from [`OpenAiError::Format`] in several cases: the response has no
    ///   choices, the first choice has no content, or the first choice finished for a reason
    ///   other than `Stop`. A missing finish reason is treated as `Stop`.
    /// - Any error from deserializing the content with `serde_json::from_str`.
    fn prompt_execute<'a, 'p, P>(
        &'a self,
        prompt: P,
    ) -> Pin<Box<dyn Future<Output = Result<GenericChatCompletionResponse<P::Output>>> + Send + 'p>>
    where
        'a: 'p,
        P: PromptTemplate + Send + Sync + 'p,
        <P as IntoPrompt>::Message: Into<Self::Message>,
    {
        // Clone the shared client handle so the future owns it instead of borrowing `self`.
        let client = Arc::clone(&self.client);
        let messages = prompt.into_prompt().into_iter().map(Into::into).collect();

        Box::pin(async move {
            let response_format = derive_response_format::<P::Output>()?;

            // Build the error lazily so `format!` only allocates on the failure path.
            let model = map_model(&P::MODEL).ok_or_else(|| {
                ArtificialError::InvalidRequest(format!(
                    "backend does not support selected model: {:?}",
                    P::MODEL
                ))
            })?;

            let request =
                ChatCompletionRequest::new(model.into(), messages).response_format(response_format);
            let response = client.chat_completion(request).await?;

            // Capture token usage now so it can be attached to the successful result.
            let usage_report = GenericUsageReport {
                prompt_tokens: response.usage.prompt_tokens as i64,
                completion_tokens: response.usage.completion_tokens as i64,
                total_tokens: response.usage.total_tokens as i64,
            };

            // Only the first choice is consumed; any further choices are ignored.
            let Some(first_choice) = response.choices.first() else {
                return Err(OpenAiError::Format("response has no choices".into()).into());
            };

            match &first_choice.finish_reason {
                // A missing finish reason is treated the same as a normal stop.
                None | Some(FinishReason::Stop) => {
                    let content = first_choice.message.content.as_ref().ok_or_else(|| {
                        OpenAiError::Format("invalid response: empty content".into())
                    })?;
                    // The content should be JSON shaped by the `response_format` sent above.
                    let content = serde_json::from_str(content.as_str())?;
                    Ok(GenericChatCompletionResponse {
                        content: ResponseContent::Finished(content),
                        usage: Some(usage_report),
                    })
                }
                Some(other) => Err(OpenAiError::Format(format!(
                    "unhandled finish reason on API: {other:?}"
                ))
                .into()),
            }
        })
    }
}
/// Builds the `response_format` payload that asks the model to answer as JSON decodable into `T`.
///
/// - If `T` is exactly [`serde_json::Value`], plain JSON mode is requested
///   (`{"type": "json_object"}`), because no schema would constrain anything.
/// - Otherwise a draft-07 JSON schema is generated for `T` with subschemas inlined. It is
///   sent as a strict `json_schema` response format. The format is named after the schema's
///   `title`, sanitized to the characters and length the API accepts.
///
/// # Errors
///
/// - [`ArtificialError::InvalidRequest`] if the generated schema has no non-empty string `title`.
/// - Any error from serializing the generated schema to a JSON value.
fn derive_response_format<T>() -> Result<serde_json::Value>
where
    T: JsonSchema + Any,
{
    // `serde_json::Value` accepts arbitrary JSON, so fall back to plain JSON mode.
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<serde_json::Value>() {
        return Ok(json!({ "type": "json_object" }));
    }

    let schema_json = {
        let mut settings = SchemaSettings::draft07();
        // Inline nested types instead of emitting `$ref`s, so the payload is one self-contained schema.
        settings.inline_subschemas = true;
        let mut generator = SchemaGenerator::new(settings);
        let root_schema = generator.root_schema_for::<T>();
        serde_json::to_value(root_schema)?
    };

    let schema_name = schema_json
        .get("title")
        .and_then(serde_json::Value::as_str)
        // An empty name would be rejected by the API; report it like a missing title.
        .filter(|title| !title.is_empty())
        .map(sanitize_schema_name)
        .ok_or_else(|| ArtificialError::InvalidRequest("json schema has no title".into()))?;

    // NOTE(review): strict structured outputs are documented to require
    // `additionalProperties: false` on every object and every property to be listed in
    // `required`. schemars does not guarantee either for all types (e.g. `Option` fields,
    // types without `deny_unknown_fields`). Confirm that the output types used here satisfy this.
    Ok(json!({
        "type": "json_schema",
        "json_schema": {
            "strict": true,
            "name": schema_name,
            "schema": schema_json,
        }
    }))
}

/// Maps a schema title onto the name format accepted for `json_schema.name`.
///
/// Every character outside `[A-Za-z0-9_-]` is replaced with `_`, and the result is truncated
/// to 64 characters. Titles that are already valid and at most 64 characters long are
/// returned unchanged.
fn sanitize_schema_name(title: &str) -> String {
    const MAX_LEN: usize = 64;
    title
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' || c == '-' {
                c
            } else {
                '_'
            }
        })
        .take(MAX_LEN)
        .collect()
}