use std::marker::PhantomData;
use std::sync::Arc;
use entelix_core::chat::ChatModel;
use entelix_core::codecs::Codec;
use entelix_core::ir::Message;
use entelix_core::transports::Transport;
use entelix_core::{ExecutionContext, Result};
use crate::runnable::Runnable;
/// Wraps a [`ChatModel`] so it can be driven as a [`Runnable`] that maps a
/// conversation (`Vec<Message>`) to a typed value `O`.
///
/// The model lives behind an [`Arc`], so cloning the adapter only bumps a
/// reference count; all clones share the same model.
pub struct StructuredOutputAdapter<O, C, T>
where
    C: Codec,
    T: Transport,
{
    /// The shared model that every invocation is delegated to.
    inner: Arc<ChatModel<C, T>>,
    /// Records the output type `O` without storing a value of it. Using
    /// `fn() -> O` keeps the adapter covariant in `O` and makes its
    /// `Send`/`Sync` status independent of `O`.
    _phantom: PhantomData<fn() -> O>,
}
impl<O, C: Codec, T: Transport> StructuredOutputAdapter<O, C, T> {
    /// Takes ownership of `model`, wraps it in a fresh [`Arc`], and builds an
    /// adapter around it.
    #[must_use]
    pub fn new(model: ChatModel<C, T>) -> Self {
        Self {
            inner: Arc::new(model),
            _phantom: PhantomData,
        }
    }

    /// Builds an adapter around an already-shared model without re-wrapping it.
    ///
    /// The adapter takes ownership of the given `Arc` handle. Callers that
    /// want to keep using the model elsewhere should pass `Arc::clone(&model)`.
    #[must_use]
    pub const fn from_arc(model: Arc<ChatModel<C, T>>) -> Self {
        Self {
            inner: model,
            _phantom: PhantomData,
        }
    }

    /// Returns the shared model handle that this adapter delegates to.
    #[must_use]
    pub const fn inner(&self) -> &Arc<ChatModel<C, T>> {
        &self.inner
    }
}
// Written by hand instead of `#[derive(Clone)]`: the derive would add
// `O: Clone`, `C: Clone` and `T: Clone` bounds. Only the `Arc` handle needs
// cloning, and it is always cloneable.
impl<O, C, T> Clone for StructuredOutputAdapter<O, C, T>
where
    C: Codec,
    T: Transport,
{
    /// Returns a new adapter that shares the same model through another
    /// `Arc` handle.
    fn clone(&self) -> Self {
        let shared_model = Arc::clone(&self.inner);
        Self {
            inner: shared_model,
            _phantom: PhantomData,
        }
    }
}
// Prints only the target output type. The wrapped model is left out, so this
// impl does not need `ChatModel<C, T>: Debug`.
impl<O, C: Codec, T: Transport> std::fmt::Debug for StructuredOutputAdapter<O, C, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StructuredOutputAdapter")
            .field("output", &std::any::type_name::<O>())
            // `inner` is deliberately omitted. `finish_non_exhaustive` prints a
            // trailing `..` so the output does not suggest `output` is the only
            // field.
            .finish_non_exhaustive()
    }
}
/// Runs the wrapped model on a conversation and yields a value of type `O`.
///
/// The work is delegated entirely to `ChatModel::complete_typed::<O>`. Its
/// result, success or error, is returned as-is.
#[async_trait::async_trait]
impl<O, C: Codec, T: Transport> Runnable<Vec<Message>, O> for StructuredOutputAdapter<O, C, T>
where
    O: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
{
    async fn invoke(&self, input: Vec<Message>, ctx: &ExecutionContext) -> Result<O> {
        let pending = self.inner.complete_typed::<O>(input, ctx);
        pending.await
    }
}
/// Extension methods for turning a chat model into a typed-output adapter.
pub trait ChatModelExt<C: Codec, T: Transport>: Sized {
    /// Consumes `self` and returns a [`StructuredOutputAdapter`] targeting
    /// output type `O`.
    ///
    /// Building the adapter does not contact the model. Nothing happens until
    /// the returned adapter is invoked, so discarding it is almost certainly a
    /// mistake.
    // `#[must_use]` goes on the trait declaration: on an impl method it has no
    // effect for calls dispatched through the trait.
    #[must_use = "the adapter does nothing until it is invoked"]
    fn with_structured_output<O>(self) -> StructuredOutputAdapter<O, C, T>
    where
        O: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static;
}
/// Gives every `ChatModel` the `with_structured_output` builder. The model is
/// moved into a new adapter via `StructuredOutputAdapter::new`.
impl<C: Codec, T: Transport> ChatModelExt<C, T> for ChatModel<C, T> {
    fn with_structured_output<O>(self) -> StructuredOutputAdapter<O, C, T>
    where
        O: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
    {
        StructuredOutputAdapter::<O, C, T>::new(self)
    }
}