use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use objectiveai_sdk::agent::completions::request::AgentCompletionCreateParams;
use objectiveai_sdk::agent::completions::response::streaming::{
AgentCompletionChunk, AgentCompletionIds,
};
use objectiveai_sdk::functions::executions::request::FunctionExecutionCreateParams;
use objectiveai_sdk::functions::executions::response::streaming::FunctionExecutionChunk;
use objectiveai_sdk::vector::completions::request::VectorCompletionCreateParams;
use objectiveai_sdk::vector::completions::response::streaming::VectorCompletionChunk;
use serde::Serialize;
use tokio::sync::{mpsc, oneshot, watch};
use tokio::task::JoinHandle;
use crate::db::Pool;
use super::row::{RowValue, RowsIter};
use super::rows::{
agent_completion_chunk_rows, function_execution_chunk_rows, vector_completion_chunk_rows,
};
use super::shadow::{Shadow, WriteOp};
use super::write::{
Tier, insert_request_blob, insert_request_messages_row, insert_response_blob,
update_response_blob, write_value,
};
/// A streamed chunk that knows which response it belongs to.
pub trait WriterChunk {
    /// Id of the response this chunk belongs to. The writer uses the id of the
    /// first chunk it receives as the response id for all of its writes.
    fn primary_id(&self) -> &str;
}

// Every supported chunk type exposes its response id through an `id` field, so
// the impls differ only in the type name.
macro_rules! impl_writer_chunk_from_id {
    ($($chunk:ident),+ $(,)?) => {
        $(
            impl WriterChunk for $chunk {
                fn primary_id(&self) -> &str {
                    self.id.as_str()
                }
            }
        )+
    };
}

impl_writer_chunk_from_id!(
    AgentCompletionChunk,
    VectorCompletionChunk,
    FunctionExecutionChunk,
);
/// Merges another chunk of the same stream into `self`.
///
/// The writer task uses this to coalesce chunks before writing them.
pub trait ChunkPush {
    /// Folds `other` into `self`.
    fn push(&mut self, other: &Self);
}

// Each impl forwards to the type's own `push`. The `Type::push` path form
// resolves to the inherent method (inherent items win over trait items), so
// this does not recurse into the trait method.
macro_rules! impl_chunk_push_via_inherent {
    ($($chunk:ident),+ $(,)?) => {
        $(
            impl ChunkPush for $chunk {
                fn push(&mut self, other: &Self) {
                    $chunk::push(self, other);
                }
            }
        )+
    };
}

impl_chunk_push_via_inherent!(
    AgentCompletionChunk,
    VectorCompletionChunk,
    FunctionExecutionChunk,
);
/// Handle to a background task that persists a stream of chunks of type `C`.
///
/// Created by `write_agent_completion`, `write_vector_completion` or
/// `write_function_execution`. Feed it with [`LogWriter::write`], then call
/// [`LogWriter::finalize`] to close the queue and wait for the task to finish
/// writing everything already queued.
pub struct LogWriter<C> {
    /// Sending half of the chunk queue consumed by the writer task.
    tx: mpsc::UnboundedSender<C>,
    /// The writer task; resolves to `Ok(())` or the first write error it hit.
    handle: JoinHandle<Result<(), crate::error::Error>>,
    /// Flips to `true` once the task has completed its first successful write.
    written_rx: watch::Receiver<bool>,
    /// Records the chunk type parameter without storing a `C`.
    _chunk: PhantomData<fn() -> C>,
}
impl<C> LogWriter<C> {
pub fn write(&self, chunk: C) -> Result<(), crate::error::Error> {
self.tx
.send(chunk)
.map_err(|_| crate::error::Error::Instance(
"log writer task has exited (earlier write failed)".to_string(),
))
}
pub fn written_once(&self) -> bool {
*self.written_rx.borrow()
}
pub async fn wait_written_once(&self) -> Result<(), crate::error::Error> {
let mut rx = self.written_rx.clone();
rx.wait_for(|b| *b)
.await
.map(|_| ())
.map_err(|_| crate::error::Error::Instance(
"log writer task exited before completing its first write".to_string(),
))
}
pub async fn finalize(self) -> Result<(), crate::error::Error> {
let LogWriter { tx, handle, .. } = self;
drop(tx);
match handle.await {
Ok(inner) => inner,
Err(e) => Err(crate::error::Error::Instance(
format!("log writer task: {e}"),
)),
}
}
}
/// Mutable state owned by the writer task for one logged request/response.
struct LogWriterState<C> {
    /// Database pool all writes go through.
    pool: Pool,
    /// Which kind of request is being logged (agent, vector or function).
    tier: Tier,
    /// Serialized create-params; inserted once, on the first applied chunk.
    request_body: serde_json::Value,
    /// Passed through unchanged to `insert_request_blob`.
    sender_agent_instance_hierarchy: String,
    /// Flattens a chunk into the row values to persist.
    rows_fn: for<'a> fn(&'a C) -> RowsIter<'a>,
    /// Response id taken from the first chunk; `None` until a chunk is applied.
    primary_id: Option<String>,
    /// Per-row bookkeeping; `record` decides insert, update or skip for each row.
    shadow: Shadow,
    /// Serialized bytes of the last response blob written, used to skip no-op updates.
    last_response_blob: Option<Vec<u8>>,
    /// Whether the request blob has been inserted yet.
    request_written: bool,
    /// Agent hierarchies that already have a request-messages row.
    seen_agents: HashSet<String>,
    /// Records the chunk type parameter without storing a `C`.
    _chunk: PhantomData<fn() -> C>,
}
impl<C> LogWriterState<C> {
fn new(
pool: Pool,
tier: Tier,
request_body: serde_json::Value,
sender_agent_instance_hierarchy: String,
rows_fn: for<'a> fn(&'a C) -> RowsIter<'a>,
) -> Self {
Self {
pool,
tier,
request_body,
sender_agent_instance_hierarchy,
rows_fn,
primary_id: None,
shadow: Shadow::new(),
last_response_blob: None,
request_written: false,
seen_agents: HashSet::new(),
_chunk: PhantomData,
}
}
async fn apply_chunk(&mut self, chunk: &C) -> Result<(), crate::error::Error>
where
C: WriterChunk + AgentCompletionIds + Serialize + Send + Sync,
{
if self.primary_id.is_none() {
self.primary_id = Some(chunk.primary_id().to_string());
}
let response_id = self.primary_id.clone().expect("set above");
let created_at_seed = now_secs() as i64;
if !self.request_written {
let request_body = self.request_body.clone();
insert_request_blob(
&self.pool,
self.tier,
&response_id,
&request_body,
&self.sender_agent_instance_hierarchy,
created_at_seed,
)
.await?;
self.request_written = true;
}
let mut buckets: HashMap<&str, Vec<(WriteOp, RowValue<'_>)>> = HashMap::new();
for value in (self.rows_fn)(chunk) {
let key = value.agent_instance_hierarchy();
match self.shadow.record(&value) {
WriteOp::Skip => continue,
op => buckets.entry(key).or_default().push((op, value)),
}
}
let response_bytes = serde_json::to_vec(chunk)?;
let blob_op = match &self.last_response_blob {
Some(prev) if prev == &response_bytes => WriteOp::Skip,
Some(_) => WriteOp::Update,
None => WriteOp::Insert,
};
let pool = &self.pool;
let tier = self.tier;
let resp_id = response_id.as_str();
let seen_agents = &mut self.seen_agents;
let bucket_futures: Vec<
Pin<Box<dyn Future<Output = Result<(), crate::error::Error>> + Send + '_>>,
> = buckets
.into_iter()
.map(|(hier, items)| {
let needs_request_row = !seen_agents.contains(hier);
if needs_request_row {
seen_agents.insert(hier.to_string());
}
Box::pin(async move {
if needs_request_row {
insert_request_messages_row(
pool,
tier,
resp_id,
hier,
created_at_seed,
)
.await?;
}
for (op, value) in &items {
write_value(pool, *op, value, created_at_seed).await?;
}
Ok::<(), crate::error::Error>(())
})
as Pin<
Box<
dyn Future<Output = Result<(), crate::error::Error>>
+ Send
+ '_,
>,
>
})
.collect();
let content_fut = futures::future::try_join_all(bucket_futures);
let blob_fut = async {
match blob_op {
WriteOp::Insert => {
insert_response_blob(
pool,
tier,
resp_id,
chunk,
created_at_seed,
)
.await
}
WriteOp::Update => {
update_response_blob(pool, tier, resp_id, chunk, created_at_seed).await
}
WriteOp::Skip => Ok(()),
}
};
let (content_res, blob_res) = tokio::join!(content_fut, blob_fut);
content_res?;
blob_res?;
if blob_op != WriteOp::Skip {
self.last_response_blob = Some(response_bytes);
}
Ok(())
}
}
/// Body of the background writer task spawned by `spawn_writer`.
///
/// Receives chunks until every sender has been dropped and all queued chunks
/// are drained. Received chunks are merged with [`ChunkPush::push`] into one
/// running aggregate that persists across iterations. Each `apply_chunk` call
/// therefore sees the whole response received so far, which its row diffing
/// (`Shadow`) and response-blob comparison rely on. Chunks already queued when
/// the task wakes are folded in first, so a burst costs one round of writes.
///
/// After the first successful write it sets `written_tx` to `true` and sends
/// the response id on `ready_tx`.
///
/// # Errors
/// Returns the first error from `apply_chunk`. The task stops there, the
/// receiver is dropped, and further `LogWriter::write` calls fail.
async fn listener_loop<C>(
    mut rx: mpsc::UnboundedReceiver<C>,
    mut state: LogWriterState<C>,
    ready_tx: oneshot::Sender<String>,
    written_tx: watch::Sender<bool>,
) -> Result<(), crate::error::Error>
where
    C: WriterChunk + AgentCompletionIds + ChunkPush + Serialize + Send + Sync,
{
    let mut ready_tx = Some(ready_tx);
    let mut written_fired = false;
    // NOTE(review): producers are assumed to send incremental chunks (hence the
    // merge via `push`). The aggregate must outlive each batch; otherwise every
    // write after the first would persist only the latest batch's deltas.
    let mut aggregate: Option<C> = None;
    while let Some(received) = rx.recv().await {
        let mut agg = match aggregate.take() {
            Some(mut acc) => {
                acc.push(&received);
                acc
            }
            None => received,
        };
        // Fold in anything else already queued so a burst costs one write round.
        while let Ok(next) = rx.try_recv() {
            agg.push(&next);
        }
        state.apply_chunk(&agg).await?;
        aggregate = Some(agg);
        if !written_fired {
            // Ignored: fails only if the `LogWriter` (the sole receiver) is gone.
            let _ = written_tx.send(true);
            written_fired = true;
        }
        if let Some(tx) = ready_tx.take() {
            match state.primary_id.as_deref() {
                Some(id) => {
                    // Ignored: the caller may have dropped the ready receiver.
                    let _ = tx.send(id.to_string());
                }
                None => {
                    ready_tx = Some(tx);
                }
            }
        }
    }
    Ok(())
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
/// Spawns the background writer task and returns its handle.
///
/// The returned receiver yields the response id once the first chunk has been
/// written. It is closed without a value if the task ends before that.
///
/// # Panics
/// Must be called from within a Tokio runtime; `tokio::spawn` panics otherwise.
fn spawn_writer<C>(
    pool: Pool,
    tier: Tier,
    request_body: serde_json::Value,
    sender_agent_instance_hierarchy: String,
    rows_fn: for<'a> fn(&'a C) -> RowsIter<'a>,
) -> (LogWriter<C>, oneshot::Receiver<String>)
where
    C: WriterChunk + AgentCompletionIds + ChunkPush + Serialize + Send + Sync + 'static,
{
    let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
    let (ready_tx, ready_rx) = oneshot::channel();
    let (written_tx, written_rx) = watch::channel(false);
    // Futures are lazy: nothing runs until the task is spawned below.
    let task = listener_loop(
        chunk_rx,
        LogWriterState::new(
            pool,
            tier,
            request_body,
            sender_agent_instance_hierarchy,
            rows_fn,
        ),
        ready_tx,
        written_tx,
    );
    let writer = LogWriter {
        tx: chunk_tx,
        handle: tokio::spawn(task),
        written_rx,
        _chunk: PhantomData,
    };
    (writer, ready_rx)
}
pub fn write_agent_completion(
pool: &Pool,
params: &AgentCompletionCreateParams,
sender_agent_instance_hierarchy: String,
) -> Result<
(LogWriter<AgentCompletionChunk>, oneshot::Receiver<String>),
crate::error::Error,
> {
let body = serde_json::to_value(params)?;
Ok(spawn_writer(
pool.clone(),
Tier::Agent,
body,
sender_agent_instance_hierarchy,
agent_completion_chunk_rows,
))
}
pub fn write_vector_completion(
pool: &Pool,
params: &VectorCompletionCreateParams,
sender_agent_instance_hierarchy: String,
) -> Result<
(LogWriter<VectorCompletionChunk>, oneshot::Receiver<String>),
crate::error::Error,
> {
let body = serde_json::to_value(params)?;
Ok(spawn_writer(
pool.clone(),
Tier::Vector,
body,
sender_agent_instance_hierarchy,
vector_completion_chunk_rows,
))
}
pub fn write_function_execution(
pool: &Pool,
params: &FunctionExecutionCreateParams,
sender_agent_instance_hierarchy: String,
) -> Result<
(LogWriter<FunctionExecutionChunk>, oneshot::Receiver<String>),
crate::error::Error,
> {
let body = serde_json::to_value(params)?;
Ok(spawn_writer(
pool.clone(),
Tier::Function,
body,
sender_agent_instance_hierarchy,
function_execution_chunk_rows,
))
}