use rustc_hash::FxHashMap;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use thiserror::Error;
use crate::{
channels::{Channel, ErrorsChannel, ExtrasChannel, MessagesChannel},
message::{Message, Role},
};
/// Lifecycle tag carried by every [`StateKey`].
///
/// `StateKey::new` always produces [`StateLifecycle::Durable`]; calling
/// `StateKey::invocation_scoped` switches a key to
/// [`StateLifecycle::InvocationScoped`]. This enum only records the tag —
/// NOTE(review): the retention policy that acts on it is implemented
/// elsewhere; confirm the exact semantics there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateLifecycle {
    /// Default lifecycle assigned by `StateKey::new`.
    Durable,
    /// Lifecycle set by `StateKey::invocation_scoped`; presumably the slot
    /// only lives for a single invocation — confirm against the runtime.
    InvocationScoped,
}
/// Typed handle to a slot in the state's extras map.
///
/// `T` is the value type stored in the slot. It is tracked only through
/// `PhantomData<fn() -> T>`, so a key never owns a `T`. That form also keeps
/// the key `Send + Sync` and covariant in `T` whatever `T` is.
///
/// `Debug` is implemented by hand, as are `Clone`, `PartialEq` and `Hash`
/// for this type, so that it does not require `T: Debug`. A
/// `#[derive(Debug)]` would add that bound even though no `T` is stored.
pub struct StateKey<T> {
    /// Namespace component of the storage key.
    namespace: &'static str,
    /// Slot name within the namespace.
    name: &'static str,
    /// Schema version, encoded into the storage key as `v{schema_version}`.
    schema_version: u32,
    /// Lifecycle tag; not part of the storage key.
    lifecycle: StateLifecycle,
    /// Zero-sized marker tying the key to its value type `T`.
    _marker: PhantomData<fn() -> T>,
}

impl<T> std::fmt::Debug for StateKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field list and order as the derived impl would print;
        // `PhantomData`'s `Debug` has no bound on its parameter.
        f.debug_struct("StateKey")
            .field("namespace", &self.namespace)
            .field("name", &self.name)
            .field("schema_version", &self.schema_version)
            .field("lifecycle", &self.lifecycle)
            .field("_marker", &self._marker)
            .finish()
    }
}
// `Copy` and `Clone` are spelled out by hand: deriving them would add
// `T: Copy` / `T: Clone` bounds even though a key never holds a `T`.
impl<T> Copy for StateKey<T> {}

impl<T> Clone for StateKey<T> {
    // Canonical implementation for a `Copy` type (see clippy's
    // `non_canonical_clone_impl`): a bitwise copy of every field.
    fn clone(&self) -> Self {
        *self
    }
}
// Keys compare equal when namespace, name and schema version all match.
// `lifecycle` is not compared, which lines up with `storage_key`: that
// string does not encode the lifecycle either. Written by hand so no
// `T: PartialEq` bound is required.
impl<T> PartialEq for StateKey<T> {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.namespace, self.name, self.schema_version);
        let rhs = (other.namespace, other.name, other.schema_version);
        lhs == rhs
    }
}

impl<T> Eq for StateKey<T> {}
// Hashes exactly the fields `PartialEq` compares (namespace, name, schema
// version), keeping `a == b => hash(a) == hash(b)`. A tuple's `Hash` feeds
// its elements to the hasher in order, so this matches hashing each field in
// turn. No `T: Hash` bound is needed.
impl<T> Hash for StateKey<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let identity = (self.namespace, self.name, self.schema_version);
        identity.hash(state);
    }
}
impl<T> StateKey<T> {
    /// Creates a key for the slot `name` inside `namespace`, tagged with
    /// `schema_version` and the [`StateLifecycle::Durable`] lifecycle.
    ///
    /// `T` is the slot's value type; it exists only at the type level.
    #[must_use]
    pub const fn new(namespace: &'static str, name: &'static str, schema_version: u32) -> Self {
        Self {
            namespace,
            name,
            schema_version,
            lifecycle: StateLifecycle::Durable,
            _marker: PhantomData,
        }
    }

    /// Returns this key with its lifecycle set to
    /// [`StateLifecycle::InvocationScoped`]; every other field is kept.
    #[must_use]
    pub const fn invocation_scoped(mut self) -> Self {
        self.lifecycle = StateLifecycle::InvocationScoped;
        self
    }

    /// Lifecycle recorded on this key.
    ///
    /// `const` so keys declared as `const` items can be inspected at compile
    /// time.
    #[must_use]
    pub const fn lifecycle(&self) -> StateLifecycle {
        self.lifecycle
    }

    /// Namespace component of the key.
    #[must_use]
    pub const fn namespace(&self) -> &'static str {
        self.namespace
    }

    /// Slot name within the namespace.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        self.name
    }

    /// Schema version of the stored value.
    #[must_use]
    pub const fn schema_version(&self) -> u32 {
        self.schema_version
    }

    /// Key under which the slot is stored in the extras map, formatted as
    /// `"{namespace}:{name}:v{schema_version}"`.
    ///
    /// The lifecycle is not encoded, so keys that differ only in lifecycle
    /// address the same slot.
    #[must_use]
    pub fn storage_key(&self) -> String {
        format!("{}:{}:v{}", self.namespace, self.name, self.schema_version)
    }
}
/// Failures produced when reading or writing a typed state slot.
///
/// Every variant carries the slot's storage key (the string produced by
/// `StateKey::storage_key`), so a failure can be traced to a specific slot.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
#[non_exhaustive]
pub enum StateSlotError {
    /// No value is stored under `key` (returned by
    /// `StateSnapshot::require_typed`).
    #[error("state slot not found: {key}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::state::slot_missing)))]
    Missing { key: String },

    /// The value destined for `key` could not be converted to JSON.
    #[error("failed to serialize state slot {key}: {source}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::state::slot_serialize)))]
    Serialize {
        key: String,
        #[source]
        source: serde_json::Error,
    },

    /// The JSON stored under `key` could not be decoded into the requested
    /// type.
    #[error("failed to deserialize state slot {key}: {source}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::state::slot_deserialize)))]
    Deserialize {
        key: String,
        #[source]
        source: serde_json::Error,
    },
}
/// State container made of three channels, each of which reports a version
/// (see `VersionedState::snapshot`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VersionedState {
    /// Conversation messages.
    pub messages: MessagesChannel,
    /// Free-form JSON values keyed by string; typed slots written through
    /// `add_typed_extra` are stored here under their storage key.
    pub extra: ExtrasChannel,
    /// Recorded error events.
    pub errors: ErrorsChannel,
}
/// Owned copy of a `VersionedState`'s channels, as returned by
/// `VersionedState::snapshot`.
///
/// Each channel's contents are paired with the version that channel reported
/// when the snapshot was taken.
#[derive(Clone, Debug)]
pub struct StateSnapshot {
    /// Messages copied from the messages channel.
    pub messages: Vec<Message>,
    /// Version of the messages channel at capture time.
    pub messages_version: u32,
    /// Extras copied from the extras channel, including typed slots.
    pub extra: FxHashMap<String, Value>,
    /// Version of the extras channel at capture time.
    pub extra_version: u32,
    /// Error events copied from the errors channel.
    pub errors: Vec<crate::channels::errors::ErrorEvent>,
    /// Version of the errors channel at capture time.
    pub errors_version: u32,
}
impl VersionedState {
    /// Creates a state whose only message is a user message with `user_text`.
    ///
    /// Equivalent to [`VersionedState::new_with_messages`] with a single
    /// `Role::User` message.
    #[must_use]
    pub fn new_with_user_message(user_text: &str) -> Self {
        Self::new_with_messages(vec![Message::with_role(Role::User, user_text)])
    }

    /// Creates a state from an initial list of messages.
    ///
    /// The messages channel starts at version `1`; the extras and errors
    /// channels are built with `Default::default()`.
    #[must_use]
    pub fn new_with_messages(messages: Vec<Message>) -> Self {
        Self {
            messages: MessagesChannel::new(messages, 1),
            extra: ExtrasChannel::default(),
            errors: ErrorsChannel::default(),
        }
    }

    /// Returns an empty [`VersionedStateBuilder`].
    #[must_use]
    pub fn builder() -> VersionedStateBuilder {
        VersionedStateBuilder::new()
    }

    /// Appends a message with the given role and content.
    ///
    /// `role` is converted with `Role::from`; NOTE(review): how unrecognised
    /// role strings are mapped is decided by `Role` — confirm there.
    /// Whether the mutation bumps the channel version depends on the channel's
    /// `get_mut` — confirm.
    ///
    /// Returns `&mut Self` only to allow chaining. The method is deliberately
    /// not `#[must_use]`: the mutation has already happened, so a plain
    /// `state.add_message(..);` statement is correct usage.
    pub fn add_message(&mut self, role: &str, content: &str) -> &mut Self {
        self.messages
            .get_mut()
            .push(Message::with_role(Role::from(role), content));
        self
    }

    /// Inserts a raw JSON `value` into the extras under `key`.
    ///
    /// Returns `&mut Self` for chaining; not `#[must_use]` for the same reason
    /// as [`VersionedState::add_message`].
    pub fn add_extra(&mut self, key: &str, value: Value) -> &mut Self {
        self.extra.get_mut().insert(key.to_string(), value);
        self
    }

    /// Serializes `value` to JSON and inserts it into the extras under
    /// `key.storage_key()`.
    ///
    /// # Errors
    ///
    /// Returns [`StateSlotError::Serialize`] if `value` cannot be converted to
    /// a `serde_json::Value`; the extras are left untouched in that case.
    pub fn add_typed_extra<T: Serialize>(
        &mut self,
        key: StateKey<T>,
        value: T,
    ) -> Result<&mut Self, StateSlotError> {
        let storage_key = key.storage_key();
        let json_value =
            serde_json::to_value(value).map_err(|source| StateSlotError::Serialize {
                key: storage_key.clone(),
                source,
            })?;
        self.extra.get_mut().insert(storage_key, json_value);
        Ok(self)
    }

    /// Captures the current contents and version of every channel.
    #[must_use]
    pub fn snapshot(&self) -> StateSnapshot {
        StateSnapshot {
            messages: self.messages.snapshot(),
            messages_version: self.messages.version(),
            extra: self.extra.snapshot(),
            extra_version: self.extra.version(),
            errors: self.errors.snapshot(),
            errors_version: self.errors.version(),
        }
    }
}
impl StateSnapshot {
    /// Reads the typed slot `key` from this snapshot's extras.
    ///
    /// Returns `Ok(None)` when nothing is stored under `key.storage_key()`.
    ///
    /// # Errors
    ///
    /// Returns [`StateSlotError::Deserialize`] when a value is present but
    /// cannot be decoded as `T`.
    pub fn get_typed<T: DeserializeOwned>(
        &self,
        key: StateKey<T>,
    ) -> Result<Option<T>, StateSlotError> {
        let storage_key = key.storage_key();
        match self.extra.get(&storage_key) {
            None => Ok(None),
            // `&Value` implements `serde::Deserializer`, so decode straight
            // from the borrowed JSON instead of cloning the whole tree first.
            Some(value) => T::deserialize(value)
                .map(Some)
                .map_err(|source| StateSlotError::Deserialize {
                    key: storage_key,
                    source,
                }),
        }
    }

    /// Like [`StateSnapshot::get_typed`], but treats an absent slot as an
    /// error.
    ///
    /// # Errors
    ///
    /// - [`StateSlotError::Missing`] when nothing is stored under the key.
    /// - [`StateSlotError::Deserialize`] when the stored value cannot be
    ///   decoded as `T`.
    pub fn require_typed<T: DeserializeOwned>(
        &self,
        key: StateKey<T>,
    ) -> Result<T, StateSlotError> {
        // `StateKey` is `Copy`, so `key` is still usable after `get_typed`.
        // The storage-key string is only built on the missing-slot path.
        self.get_typed(key)?
            .ok_or_else(|| StateSlotError::Missing {
                key: key.storage_key(),
            })
    }
}
/// Accumulates messages and extras before producing a `VersionedState`.
///
/// Obtain one through `VersionedState::builder`.
#[derive(Default, Debug)]
pub struct VersionedStateBuilder {
    /// Messages, kept in the order they were added.
    messages: Vec<Message>,
    /// Extras keyed by string; inserting an existing key replaces its value.
    extra: FxHashMap<String, Value>,
}
impl VersionedStateBuilder {
fn new() -> Self {
Self::default()
}
pub fn with_user_message(mut self, content: &str) -> Self {
self.messages.push(Message::with_role(Role::User, content));
self
}
pub fn with_assistant_message(mut self, content: &str) -> Self {
self.messages
.push(Message::with_role(Role::Assistant, content));
self
}
pub fn with_system_message(mut self, content: &str) -> Self {
self.messages
.push(Message::with_role(Role::System, content));
self
}
pub fn with_message(mut self, role: &str, content: &str) -> Self {
self.messages
.push(Message::with_role(Role::from(role), content));
self
}
pub fn with_extra(mut self, key: &str, value: Value) -> Self {
self.extra.insert(key.to_string(), value);
self
}
pub fn with_typed_extra<T: Serialize>(
mut self,
key: StateKey<T>,
value: T,
) -> Result<Self, StateSlotError> {
let storage_key = key.storage_key();
let json_value =
serde_json::to_value(value).map_err(|source| StateSlotError::Serialize {
key: storage_key.clone(),
source,
})?;
self.extra.insert(storage_key, json_value);
Ok(self)
}
pub fn build(self) -> VersionedState {
VersionedState {
messages: MessagesChannel::new(self.messages, 1),
extra: ExtrasChannel::new(self.extra, 1),
errors: ErrorsChannel::default(),
}
}
}