Skip to main content

lash_plugin_observational_memory/
lib.rs

1use std::sync::Arc;
2
3use lash_core::plugin::{
4    PluginError, PluginFactory, PluginLifecycleEvent, PluginRegistrar, PluginSessionContext,
5    SessionPlugin,
6};
7use lash_core::{SessionAppendNode, SessionStateChangedContext};
8
9mod constants;
10mod context_transform;
11mod graph_state;
12mod host;
13mod model;
14mod prompts;
15mod transitions;
16mod worker;
17
18pub use constants::{
19    ACTIVE_STATE_PLUGIN_TYPE, BUFFERED_OBSERVATION_PLUGIN_TYPE, BUFFERED_REFLECTION_PLUGIN_TYPE,
20    OBSERVATIONAL_MEMORY_PLUGIN_ID,
21};
22
23use context_transform::ObservationalMemoryTransform;
24use host::OmRuntimeHost;
25use transitions::{
26    maybe_buffer_observations, maybe_buffer_reflection, should_run_async_maintenance,
27};
28
29#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
30pub struct ObservationalMemoryConfig {
31    pub observation_message_tokens: usize,
32    pub observation_buffer_tokens: usize,
33    pub observation_block_after_tokens: usize,
34    pub observation_max_tokens_per_batch: usize,
35    pub previous_observer_tokens: usize,
36    pub reflection_observation_tokens: usize,
37    #[serde(default = "default_reflection_buffer_activation_bps")]
38    pub reflection_buffer_activation_bps: u16,
39    pub reflection_block_after_tokens: usize,
40}
41
42impl Default for ObservationalMemoryConfig {
43    fn default() -> Self {
44        Self {
45            observation_message_tokens: 30_000,
46            observation_buffer_tokens: 6_000,
47            observation_block_after_tokens: 36_000,
48            observation_max_tokens_per_batch: 10_000,
49            previous_observer_tokens: 2_000,
50            reflection_observation_tokens: 40_000,
51            reflection_buffer_activation_bps: default_reflection_buffer_activation_bps(),
52            reflection_block_after_tokens: 48_000,
53        }
54    }
55}
56
57impl ObservationalMemoryConfig {
58    pub fn observation_buffer_interval_tokens(&self) -> usize {
59        self.observation_buffer_tokens
60    }
61
62    pub fn observation_retention_tokens(&self) -> usize {
63        self.observation_buffer_tokens
64    }
65
66    pub fn reflection_buffer_activation_tokens(&self) -> usize {
67        self.reflection_observation_tokens
68            .saturating_mul(self.reflection_buffer_activation_bps as usize)
69            / 10_000
70    }
71}
72
/// Fallback for `ObservationalMemoryConfig::reflection_buffer_activation_bps`, used both by
/// serde (missing field) and by `Default`: 50% expressed in basis points.
const fn default_reflection_buffer_activation_bps() -> u16 {
    const FULL_SCALE_BPS: u16 = 10_000;
    FULL_SCALE_BPS / 2
}
76
77pub fn active_memory_state_node(
78    body: impl serde::Serialize,
79) -> Result<SessionAppendNode, serde_json::Error> {
80    Ok(SessionAppendNode::plugin(
81        ACTIVE_STATE_PLUGIN_TYPE,
82        serde_json::to_value(body)?,
83    ))
84}
85
86#[derive(Clone, Debug)]
87pub struct ObservationalMemoryPluginFactory {
88    config: ObservationalMemoryConfig,
89}
90
91impl ObservationalMemoryPluginFactory {
92    pub fn new(config: ObservationalMemoryConfig) -> Self {
93        Self { config }
94    }
95}
96
97impl Default for ObservationalMemoryPluginFactory {
98    fn default() -> Self {
99        Self::new(ObservationalMemoryConfig::default())
100    }
101}
102
103impl PluginFactory for ObservationalMemoryPluginFactory {
104    fn id(&self) -> &'static str {
105        OBSERVATIONAL_MEMORY_PLUGIN_ID
106    }
107
108    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
109        Ok(Arc::new(ObservationalMemoryPlugin {
110            config: self.config.clone(),
111        }))
112    }
113}
114
115struct ObservationalMemoryPlugin {
116    config: ObservationalMemoryConfig,
117}
118
119impl SessionPlugin for ObservationalMemoryPlugin {
120    fn id(&self) -> &'static str {
121        OBSERVATIONAL_MEMORY_PLUGIN_ID
122    }
123
124    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
125        reg.context().prepare_turn(
126            100,
127            Arc::new(ObservationalMemoryTransform::new(self.config.clone())),
128        );
129
130        let config = self.config.clone();
131        reg.session()
132            .on_event(observational_memory_event_hook(config));
133
134        Ok(())
135    }
136}
137
138fn observational_memory_event_hook(
139    config: ObservationalMemoryConfig,
140) -> lash_core::plugin::PluginLifecycleEventHook {
141    Arc::new(move |event| {
142        let config = config.clone();
143        Box::pin(async move {
144            if let PluginLifecycleEvent::TurnPersisted(ctx) = event {
145                maybe_spawn_post_persist_memory_maintenance(config, *ctx).await?;
146            }
147            Ok(())
148        })
149    })
150}
151
152async fn maybe_spawn_post_persist_memory_maintenance(
153    config: ObservationalMemoryConfig,
154    ctx: SessionStateChangedContext<'_>,
155) -> Result<(), PluginError> {
156    let graph = ctx.state.session_graph();
157    if !should_run_async_maintenance(&config, graph) {
158        return Ok(());
159    }
160    run_async_maintenance(config, graph, &ctx).await
161}
162
163async fn run_async_maintenance(
164    config: ObservationalMemoryConfig,
165    graph: &lash_core::SessionGraph,
166    ctx: &SessionStateChangedContext<'_>,
167) -> Result<(), PluginError> {
168    let om_host = OmRuntimeHost::new(
169        &ctx.session_id,
170        &ctx.session_graph,
171        ctx.direct_completions.clone(),
172    );
173    maybe_buffer_observations(&config, &om_host, ctx.state.policy(), graph).await?;
174    maybe_buffer_reflection(&config, &om_host, ctx.state.policy(), graph).await?;
175    Ok(())
176}
177
178#[cfg(test)]
179mod tests;