use std::sync::Arc;
use crate::checkpoint::CheckpointSaver;
use crate::config::{EntrypointConfig, TaskConfig};
use crate::graph::{StateGraph, TopologyError};
use crate::node::IntoNode;
use crate::runtime::Runtime as CoreRuntime;
use crate::state::{FromState, IntoState, State};
use crate::store::Store;
/// Entrypoint-facing runtime that pairs a core runtime with optional services.
///
/// Holds an optional previous value, an optional checkpoint saver, an optional
/// store and the wrapped core runtime. Cloning duplicates the `Arc` handles,
/// not the services behind them.
#[derive(Clone)]
pub struct Runtime<S>
where
    S: State + Default,
{
    /// Optional JSON value; `from_core` copies it from `core.previous`.
    // NOTE(review): presumably the value produced by an earlier run — confirm.
    pub previous: Option<serde_json::Value>,
    /// Shared checkpoint saver, if one has been attached.
    pub checkpointer: Option<Arc<dyn CheckpointSaver>>,
    /// Shared store, if one has been attached.
    pub store: Option<Arc<dyn Store>>,
    /// The wrapped core runtime.
    pub core: CoreRuntime<S>,
}
impl<S> std::fmt::Debug for Runtime<S>
where
    S: State + Default + std::fmt::Debug,
{
    /// Formats the runtime as a `Runtime { .. }` debug struct.
    ///
    /// The checkpointer and store are rendered as fixed placeholder strings
    /// (wrapped in `Some` when present, `None` otherwise) instead of the
    /// contents of the trait objects.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only presence/absence of the services is reported.
        let checkpointer_tag = self.checkpointer.as_ref().map(|_| "<CheckpointSaver>");
        let store_tag = self.store.as_ref().map(|_| "<Store>");

        f.debug_struct("Runtime")
            .field("previous", &self.previous)
            .field("checkpointer", &checkpointer_tag)
            .field("store", &store_tag)
            .field("core", &self.core)
            .finish()
    }
}
impl<S> Runtime<S>
where
    S: State + Default,
{
    /// Creates a runtime with no previous value, no checkpointer and no store,
    /// wrapping a freshly constructed core runtime.
    #[must_use]
    pub fn new() -> Self
    where
        S: Default,
    {
        let core = CoreRuntime::new();
        Self {
            previous: None,
            checkpointer: None,
            store: None,
            core,
        }
    }

    /// Builds a runtime from an existing core runtime.
    ///
    /// `previous` and `store` are cloned from `core`, and `core` itself is
    /// cloned into the new value. The checkpointer is left unset; attach one
    /// with [`Runtime::with_checkpointer`] if needed.
    #[must_use]
    pub fn from_core(core: &CoreRuntime<S>) -> Self {
        let previous = core.previous.clone();
        let store = core.store.clone();
        Self {
            previous,
            checkpointer: None,
            store,
            core: core.clone(),
        }
    }

    /// Builds a runtime whose checkpointer and store are cloned from `config`.
    ///
    /// `previous` is unset and the core runtime is freshly constructed.
    #[must_use]
    pub fn from_entrypoint_config(config: &EntrypointConfig) -> Self
    where
        S: Default,
    {
        let mut runtime = Self::new();
        runtime.checkpointer = config.checkpointer.clone();
        runtime.store = config.store.clone();
        runtime
    }

    /// Returns the runtime with `previous` set to the given value.
    ///
    /// Only this wrapper's field is replaced; `core` is left untouched.
    #[must_use]
    pub fn with_previous(self, previous: serde_json::Value) -> Self {
        Self {
            previous: Some(previous),
            ..self
        }
    }

    /// Returns the runtime with the given checkpoint saver attached.
    #[must_use]
    pub fn with_checkpointer(self, checkpointer: Arc<dyn CheckpointSaver>) -> Self {
        Self {
            checkpointer: Some(checkpointer),
            ..self
        }
    }

    /// Returns the runtime with the given store attached.
    ///
    /// Only this wrapper's field is replaced; `core` is left untouched.
    #[must_use]
    pub fn with_store(self, store: Arc<dyn Store>) -> Self {
        Self {
            store: Some(store),
            ..self
        }
    }

    /// Returns the runtime with its core runtime replaced by `core`.
    ///
    /// The other fields keep their current values.
    #[must_use]
    pub fn with_core(self, core: CoreRuntime<S>) -> Self {
        Self { core, ..self }
    }
}
impl<S: State + Default> Default for Runtime<S> {
fn default() -> Self {
Self::new()
}
}
/// Compiles `func` into a graph using the default [`TaskConfig`].
///
/// This is a convenience wrapper around [`compile_entrypoint_with_config`].
///
/// # Errors
///
/// Returns whatever [`TopologyError`] `compile_entrypoint_with_config`
/// reports.
pub fn compile_entrypoint<S, I, O, F>(
    func: F,
    checkpointer: Option<Arc<dyn CheckpointSaver>>,
) -> Result<crate::graph::CompiledGraph<S, I, O>, TopologyError>
where
    S: State + Default,
    F: IntoNode<S>,
    I: IntoState<S>,
    O: FromState<S>,
{
    let default_config = TaskConfig::default();
    compile_entrypoint_with_config(func, &default_config, checkpointer)
}
/// Compiles `func` into a single-node graph configured by `config`.
///
/// The node is registered under `config.name`, falling back to
/// `"__entrypoint__"` when no name is given, and is set as both the entry
/// point and the finish point of the graph. If `config.retry_policy` is set,
/// it is passed to the node as a one-element list; otherwise the list is
/// empty. The graph is then compiled with the supplied `checkpointer`.
///
/// # Errors
///
/// Returns a [`TopologyError`] if adding the node fails or if compiling the
/// graph fails.
// NOTE(review): only `name` and `retry_policy` are read from `config` here;
// `cache_policy` and `timeout` are not forwarded — confirm this is intended.
pub fn compile_entrypoint_with_config<S, I, O, F>(
    func: F,
    config: &TaskConfig,
    checkpointer: Option<Arc<dyn CheckpointSaver>>,
) -> Result<crate::graph::CompiledGraph<S, I, O>, TopologyError>
where
    S: State + Default,
    F: IntoNode<S>,
    I: IntoState<S>,
    O: FromState<S>,
{
    // Resolve the node name, using the fixed fallback when none is configured.
    let node_name = match &config.name {
        Some(name) => name.clone(),
        None => String::from("__entrypoint__"),
    };

    // At most one retry policy is forwarded to the node.
    let node_retry_policies = match config.retry_policy.as_ref() {
        Some(policy) => vec![policy.clone()],
        None => Vec::new(),
    };

    let mut builder = StateGraph::<S, I, O>::new();
    builder.add_node(
        &node_name,
        func,
        false,
        None,
        None,
        node_retry_policies,
        Vec::new(),
    )?;

    // The single node is both where execution starts and where it ends.
    builder.set_entry_point(&node_name);
    builder.set_finish_point(&node_name);

    builder.compile_with_checkpointer(checkpointer)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JunctureError;
    use crate::node::NodeFnUpdate;
    use crate::state::MessagesState;

    /// State type shared by every test in this module.
    type TestState = MessagesState;
    /// Update type associated with `TestState`.
    type TestStateUpdate = <TestState as State>::Update;

    /// `Runtime::new` leaves every optional field unset.
    #[test]
    fn test_runtime_new() {
        let fresh = Runtime::<TestState>::new();
        let Runtime {
            previous,
            checkpointer,
            store,
            ..
        } = &fresh;
        assert!(previous.is_none());
        assert!(checkpointer.is_none());
        assert!(store.is_none());
    }

    /// `Default::default` leaves every optional field unset.
    #[test]
    fn test_runtime_default() {
        let defaulted: Runtime<TestState> = Default::default();
        let Runtime {
            previous,
            checkpointer,
            store,
            ..
        } = &defaulted;
        assert!(previous.is_none());
        assert!(checkpointer.is_none());
        assert!(store.is_none());
    }

    /// `with_previous` stores the supplied JSON value.
    #[test]
    fn test_runtime_with_previous() {
        let expected = serde_json::json!("previous_value");
        let runtime = Runtime::<TestState>::new().with_previous(expected.clone());
        assert_eq!(runtime.previous, Some(expected));
    }

    /// A config without services yields a runtime without services.
    #[test]
    fn test_runtime_from_entrypoint_config() {
        let empty_config = EntrypointConfig {
            checkpointer: None,
            store: None,
        };
        let built = Runtime::<TestState>::from_entrypoint_config(&empty_config);
        assert!(built.checkpointer.is_none());
        assert!(built.store.is_none());
    }

    /// Cloning compiles and leaves the original value usable.
    #[test]
    fn test_runtime_clone() {
        let original = Runtime::<TestState>::new();
        let _copy = original.clone();
        assert!(original.previous.is_none());
        assert!(original.checkpointer.is_none());
    }

    /// A no-op node compiles with the default configuration.
    #[test]
    fn test_compile_entrypoint_basic() {
        let compiled = compile_entrypoint::<TestState, TestState, TestState, _>(
            NodeFnUpdate(|_state: &TestState| async {
                Ok::<TestStateUpdate, JunctureError>(TestStateUpdate::default())
            }),
            None,
        );
        compiled.unwrap();
    }

    /// A no-op node compiles with a custom name and retry policy.
    #[test]
    fn test_compile_entrypoint_with_config() {
        let three_attempts = crate::graph::RetryPolicy {
            max_attempts: 3,
            ..Default::default()
        };
        let custom_config = TaskConfig {
            retry_policy: Some(three_attempts),
            cache_policy: None,
            timeout: None,
            name: Some("custom_entrypoint".to_string()),
        };
        let compiled = compile_entrypoint_with_config::<TestState, TestState, TestState, _>(
            NodeFnUpdate(|_state: &TestState| async {
                Ok::<TestStateUpdate, JunctureError>(TestStateUpdate::default())
            }),
            &custom_config,
            None,
        );
        compiled.unwrap();
    }
}