lash_plugin_observational_memory/
lib.rs1use std::sync::Arc;
2
3use lash_core::plugin::{
4 PluginError, PluginFactory, PluginLifecycleEvent, PluginRegistrar, PluginSessionContext,
5 SessionPlugin,
6};
7use lash_core::{SessionAppendNode, SessionStateChangedContext};
8
9mod constants;
10mod context_transform;
11mod graph_state;
12mod host;
13mod model;
14mod prompts;
15mod transitions;
16mod worker;
17
18pub use constants::{
19 ACTIVE_STATE_PLUGIN_TYPE, BUFFERED_OBSERVATION_PLUGIN_TYPE, BUFFERED_REFLECTION_PLUGIN_TYPE,
20 OBSERVATIONAL_MEMORY_PLUGIN_ID,
21};
22
23use context_transform::ObservationalMemoryTransform;
24use host::OmRuntimeHost;
25use transitions::{
26 maybe_buffer_observations, maybe_buffer_reflection, should_run_async_maintenance,
27};
28
29#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
30pub struct ObservationalMemoryConfig {
31 pub observation_message_tokens: usize,
32 pub observation_buffer_tokens: usize,
33 pub observation_block_after_tokens: usize,
34 pub observation_max_tokens_per_batch: usize,
35 pub previous_observer_tokens: usize,
36 pub reflection_observation_tokens: usize,
37 #[serde(default = "default_reflection_buffer_activation_bps")]
38 pub reflection_buffer_activation_bps: u16,
39 pub reflection_block_after_tokens: usize,
40}
41
42impl Default for ObservationalMemoryConfig {
43 fn default() -> Self {
44 Self {
45 observation_message_tokens: 30_000,
46 observation_buffer_tokens: 6_000,
47 observation_block_after_tokens: 36_000,
48 observation_max_tokens_per_batch: 10_000,
49 previous_observer_tokens: 2_000,
50 reflection_observation_tokens: 40_000,
51 reflection_buffer_activation_bps: default_reflection_buffer_activation_bps(),
52 reflection_block_after_tokens: 48_000,
53 }
54 }
55}
56
57impl ObservationalMemoryConfig {
58 pub fn observation_buffer_interval_tokens(&self) -> usize {
59 self.observation_buffer_tokens
60 }
61
62 pub fn observation_retention_tokens(&self) -> usize {
63 self.observation_buffer_tokens
64 }
65
66 pub fn reflection_buffer_activation_tokens(&self) -> usize {
67 self.reflection_observation_tokens
68 .saturating_mul(self.reflection_buffer_activation_bps as usize)
69 / 10_000
70 }
71}
72
/// Default reflection-buffer activation ratio in basis points: half
/// (10_000 / 2 = 5_000 bps) of `reflection_observation_tokens`.
///
/// Used both by `ObservationalMemoryConfig::default()` and as the serde fallback
/// when the field is missing from serialized config.
const fn default_reflection_buffer_activation_bps() -> u16 {
    10_000 / 2
}
76
77pub fn active_memory_state_node(
78 body: impl serde::Serialize,
79) -> Result<SessionAppendNode, serde_json::Error> {
80 Ok(SessionAppendNode::plugin(
81 ACTIVE_STATE_PLUGIN_TYPE,
82 serde_json::to_value(body)?,
83 ))
84}
85
86#[derive(Clone, Debug)]
87pub struct ObservationalMemoryPluginFactory {
88 config: ObservationalMemoryConfig,
89}
90
91impl ObservationalMemoryPluginFactory {
92 pub fn new(config: ObservationalMemoryConfig) -> Self {
93 Self { config }
94 }
95}
96
97impl Default for ObservationalMemoryPluginFactory {
98 fn default() -> Self {
99 Self::new(ObservationalMemoryConfig::default())
100 }
101}
102
103impl PluginFactory for ObservationalMemoryPluginFactory {
104 fn id(&self) -> &'static str {
105 OBSERVATIONAL_MEMORY_PLUGIN_ID
106 }
107
108 fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
109 Ok(Arc::new(ObservationalMemoryPlugin {
110 config: self.config.clone(),
111 }))
112 }
113}
114
115struct ObservationalMemoryPlugin {
116 config: ObservationalMemoryConfig,
117}
118
119impl SessionPlugin for ObservationalMemoryPlugin {
120 fn id(&self) -> &'static str {
121 OBSERVATIONAL_MEMORY_PLUGIN_ID
122 }
123
124 fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
125 reg.context().prepare_turn(
126 100,
127 Arc::new(ObservationalMemoryTransform::new(self.config.clone())),
128 );
129
130 let config = self.config.clone();
131 reg.session()
132 .on_event(observational_memory_event_hook(config));
133
134 Ok(())
135 }
136}
137
138fn observational_memory_event_hook(
139 config: ObservationalMemoryConfig,
140) -> lash_core::plugin::PluginLifecycleEventHook {
141 Arc::new(move |event| {
142 let config = config.clone();
143 Box::pin(async move {
144 if let PluginLifecycleEvent::TurnPersisted(ctx) = event {
145 maybe_spawn_post_persist_memory_maintenance(config, *ctx).await?;
146 }
147 Ok(())
148 })
149 })
150}
151
152async fn maybe_spawn_post_persist_memory_maintenance(
153 config: ObservationalMemoryConfig,
154 ctx: SessionStateChangedContext<'_>,
155) -> Result<(), PluginError> {
156 let graph = ctx.state.session_graph();
157 if !should_run_async_maintenance(&config, graph) {
158 return Ok(());
159 }
160 run_async_maintenance(config, graph, &ctx).await
161}
162
163async fn run_async_maintenance(
164 config: ObservationalMemoryConfig,
165 graph: &lash_core::SessionGraph,
166 ctx: &SessionStateChangedContext<'_>,
167) -> Result<(), PluginError> {
168 let om_host = OmRuntimeHost::new(
169 &ctx.session_id,
170 &ctx.session_graph,
171 ctx.direct_completions.clone(),
172 );
173 maybe_buffer_observations(&config, &om_host, ctx.state.policy(), graph).await?;
174 maybe_buffer_reflection(&config, &om_host, ctx.state.policy(), graph).await?;
175 Ok(())
176}
177
178#[cfg(test)]
179mod tests;