use async_trait::async_trait;
use std::borrow::Cow;
use lash_core::PreparedContext;
use lash_core::plugin::{ContextError, TurnContextTransform, TurnTransformContext};
use lash_core::{Message, MessageOrigin, MessageRole, Part, PartKind};
use crate::ObservationalMemoryConfig;
use crate::constants::{
OBSERVATION_CONTEXT_INSTRUCTIONS, OBSERVATION_CONTEXT_PROMPT, OBSERVATION_CONTINUATION_HINT,
OBSERVATIONAL_MEMORY_PLUGIN_ID,
};
use crate::graph_state::{
active_unobserved_message_nodes, approx_message_nodes_tokens, approx_token_count,
build_graph_state,
};
use crate::host::OmRuntimeHost;
use crate::model::ActiveMemoryState;
use crate::transitions::maybe_advance_memory_state;
/// Turn-context transform for the observational-memory plugin.
///
/// Its `TurnContextTransform` implementation rewrites a prepared context so
/// that plugin-generated memory messages are placed after the leading system
/// messages.
pub(crate) struct ObservationalMemoryTransform {
/// Plugin settings; the transform reads its token thresholds from here.
config: ObservationalMemoryConfig,
}
impl ObservationalMemoryTransform {
/// Creates a transform that uses the given configuration.
pub(crate) fn new(config: ObservationalMemoryConfig) -> Self {
Self { config }
}
}
#[async_trait]
impl TurnContextTransform for ObservationalMemoryTransform {
fn id(&self) -> &'static str {
"observational_memory.prepare_turn"
}
async fn transform(
&self,
ctx: &TurnTransformContext<'_>,
input: PreparedContext,
) -> Result<PreparedContext, ContextError> {
let graph = ctx.state.session_graph();
let om_state = build_graph_state(graph);
let pending_message_tokens = approx_message_nodes_tokens(&active_unobserved_message_nodes(
graph,
om_state
.active
.as_ref()
.and_then(|state| state.observed_through_message_id.as_deref()),
));
let active_observation_tokens = om_state
.active
.as_ref()
.map(|state| approx_token_count(&state.observations))
.unwrap_or(0);
let should_advance_memory = pending_message_tokens
>= self.config.observation_message_tokens
|| active_observation_tokens >= self.config.reflection_observation_tokens;
let graph = if should_advance_memory {
let mut graph = ctx.sessions.snapshot_current().await?.session_graph;
graph = maybe_advance_memory_state(
&self.config,
&OmRuntimeHost::new(
&ctx.session_id,
&ctx.session_graph,
ctx.direct_completions.clone(),
),
ctx.state.policy(),
graph,
)
.await?;
Cow::Owned(graph)
} else {
Cow::Borrowed(graph)
};
let Some(active) = build_graph_state(graph.as_ref()).active else {
return Ok(input);
};
if active.observations.trim().is_empty()
&& active.current_task.is_none()
&& active.suggested_response.is_none()
{
return Ok(input);
}
let input_messages = input.messages.as_slice();
let prefix_len = input_messages
.iter()
.take_while(|message| matches!(message.role, MessageRole::System))
.count();
let tail_start = memory_tail_start(
input_messages,
prefix_len,
active.observed_through_message_id.as_deref(),
);
let mut messages = Vec::new();
messages.extend_from_slice(&input_messages[..prefix_len]);
messages.extend(build_memory_context_messages(&active));
messages.extend_from_slice(&input_messages[tail_start..]);
let base = std::sync::Arc::new(messages);
let cache = std::sync::Arc::new(lash_core::BaseRenderCache::new());
Ok(PreparedContext {
messages: lash_core::MessageSequence::from_base(base).with_base_render_cache(cache),
..input
})
}
}
/// Returns the index in `input_messages` at which the retained tail begins.
///
/// The tail starts immediately after the message whose id equals
/// `observed_through_message_id`. The result is never smaller than
/// `prefix_len`, and it is exactly `prefix_len` when no marker is given or the
/// marker id is not present in `input_messages`.
fn memory_tail_start(
    input_messages: &[Message],
    prefix_len: usize,
    observed_through_message_id: Option<&str>,
) -> usize {
    let Some(marker_id) = observed_through_message_id else {
        return prefix_len;
    };
    match input_messages
        .iter()
        .position(|message| message.id == marker_id)
    {
        // Clamp so a marker inside the prefix does not re-include prefix messages.
        Some(marker_idx) => (marker_idx + 1).max(prefix_len),
        None => prefix_len,
    }
}
/// Builds the three plugin messages that carry the active memory.
///
/// In order:
/// 1. a system message with the context prompt and instructions,
/// 2. a system message with the trimmed observations in `<observations>` tags,
///    followed by `<current-task>` and `<suggested-response>` sections when
///    those fields are present,
/// 3. a user message wrapping the continuation hint in `<system-reminder>` tags.
fn build_memory_context_messages(active: &ActiveMemoryState) -> Vec<Message> {
    let mut memory_block = [
        "<observations>\n",
        active.observations.trim(),
        "\n</observations>",
    ]
    .concat();
    if let Some(task) = &active.current_task {
        memory_block += &format!("\n\n<current-task>\n{}\n</current-task>", task.trim());
    }
    if let Some(reply) = &active.suggested_response {
        memory_block += &format!(
            "\n\n<suggested-response>\n{}\n</suggested-response>",
            reply.trim()
        );
    }

    vec![
        plugin_message(
            "om-memory-system",
            MessageRole::System,
            format!("{OBSERVATION_CONTEXT_PROMPT}\n\n{OBSERVATION_CONTEXT_INSTRUCTIONS}"),
        ),
        plugin_message("om-memory-block", MessageRole::System, memory_block),
        plugin_message(
            "om-memory-reminder",
            MessageRole::User,
            format!("<system-reminder>{OBSERVATION_CONTINUATION_HINT}</system-reminder>"),
        ),
    ]
}
/// Wraps `content` in a single-part prose message attributed to this plugin.
///
/// The message takes `id` as its id and its only part is named `"{id}.p0"`.
/// The origin is `MessageOrigin::Plugin` with this plugin's id and
/// `transient: true`.
fn plugin_message(id: &str, role: MessageRole, content: String) -> Message {
    let prose_part = Part {
        id: format!("{id}.p0"),
        kind: PartKind::Prose,
        content,
        attachment: None,
        tool_call_id: None,
        tool_name: None,
        tool_replay: None,
        prune_state: lash_core::PruneState::Intact,
        reasoning_meta: None,
        response_meta: None,
    };
    let plugin_origin = MessageOrigin::Plugin {
        plugin_id: OBSERVATIONAL_MEMORY_PLUGIN_ID.to_string(),
        transient: true,
    };

    Message {
        id: id.to_owned(),
        role,
        parts: lash_core::shared_parts(vec![prose_part]),
        origin: Some(plugin_origin),
    }
}
// Unit tests for `memory_tail_start`: marker after the prefix, marker inside
// the prefix, and marker missing or not supplied.
#[cfg(test)]
mod tail_start_tests {
use super::*;
// Builds a minimal message with the given id and role; the content is irrelevant here.
fn msg(id: &str, role: MessageRole) -> Message {
plugin_message(id, role, "x".to_string())
}
// Marker "u1" sits at index 1, past the one-message prefix, so the tail starts at index 2.
#[test]
fn marker_after_prefix_starts_tail_right_after_marker() {
let messages = [
msg("s0", MessageRole::System),
msg("u1", MessageRole::User),
msg("a2", MessageRole::Assistant),
msg("u3", MessageRole::User),
];
assert_eq!(memory_tail_start(&messages, 1, Some("u1")), 2);
}
// Marker "s0" lies inside the two-message prefix; the result is clamped to
// `prefix_len` so no prefix message is included in the tail.
#[test]
fn marker_inside_prefix_does_not_re_include_prefix() {
let messages = [
msg("s0", MessageRole::System),
msg("s1", MessageRole::System),
msg("u2", MessageRole::User),
];
assert_eq!(memory_tail_start(&messages, 2, Some("s0")), 2);
}
// An unknown marker id and an absent marker both yield `prefix_len`.
#[test]
fn missing_marker_falls_back_to_prefix_len() {
let messages = [msg("s0", MessageRole::System), msg("u1", MessageRole::User)];
assert_eq!(memory_tail_start(&messages, 1, Some("nonexistent")), 1);
assert_eq!(memory_tail_start(&messages, 1, None), 1);
}
}