lash-plugin-observational-memory 0.1.0-alpha.36

Observational-memory plugin for the lash agent runtime.
Documentation
use std::collections::HashMap;

use lash_core::{MessageRole, SessionGraph};

use crate::constants::{
    ACTIVE_STATE_PLUGIN_TYPE, BUFFERED_OBSERVATION_PLUGIN_TYPE, BUFFERED_REFLECTION_PLUGIN_TYPE,
};
use crate::model::{
    ActiveMemoryNode, ActiveMemoryState, BufferedObservationNode, BufferedObservationState,
    BufferedReflectionNode, BufferedReflectionState, MessageNode, ObservedMessageNode,
    OmGraphState,
};
use crate::prompts::format_message_for_observer;

/// Folds the plugin nodes on the graph's active path into an [`OmGraphState`].
///
/// Nodes are visited in `active_path_nodes()` order:
/// - an active-state node replaces `active` and discards every buffered observation and the
///   buffered reflection collected before it;
/// - a buffered-observation node is appended unless a chunk with the same
///   `observed_through_message_id` is already buffered;
/// - a buffered-reflection node is kept only when its `source_state_node_id` equals the current
///   active state's `state_node_id`; it is ignored while there is no active state.
///
/// Nodes that are not plugin nodes, unknown plugin kinds, and nodes for which `plugin_body`
/// returns `None` are skipped.
pub(crate) fn build_graph_state(graph: &SessionGraph) -> OmGraphState {
    let mut om = OmGraphState::default();
    for node in graph.active_path_nodes() {
        let kind = match node.plugin() {
            Some((kind, _)) => kind,
            None => continue,
        };
        match kind {
            ACTIVE_STATE_PLUGIN_TYPE => {
                if let Some(body) = node.plugin_body::<ActiveMemoryNode>() {
                    om.active = Some(ActiveMemoryState {
                        state_node_id: node.node_id.clone(),
                        observed_through_message_id: Some(body.observed_through_message_id),
                        observations: body.observations,
                        current_task: body.current_task,
                        suggested_response: body.suggested_response,
                    });
                    // A newer active state supersedes anything buffered before it.
                    om.buffered_observations.clear();
                    om.buffered_reflection = None;
                }
            }
            BUFFERED_OBSERVATION_PLUGIN_TYPE => {
                if let Some(body) = node.plugin_body::<BufferedObservationNode>() {
                    let already_buffered = om.buffered_observations.iter().any(|chunk| {
                        chunk.observed_through_message_id == body.observed_through_message_id
                    });
                    if !already_buffered {
                        om.buffered_observations.push(BufferedObservationState {
                            observed_through_message_id: body.observed_through_message_id,
                            observations: body.observations,
                            current_task: body.current_task,
                            suggested_response: body.suggested_response,
                        });
                    }
                }
            }
            BUFFERED_REFLECTION_PLUGIN_TYPE => {
                if let Some(body) = node.plugin_body::<BufferedReflectionNode>() {
                    // Only a reflection built from the current active state is relevant.
                    let belongs_to_active = om
                        .active
                        .as_ref()
                        .is_some_and(|active| body.source_state_node_id == active.state_node_id);
                    if belongs_to_active {
                        om.buffered_reflection = Some(BufferedReflectionState {
                            source_state_node_id: body.source_state_node_id,
                            observed_through_message_id: body.observed_through_message_id,
                            observations: body.observations,
                            current_task: body.current_task,
                            suggested_response: body.suggested_response,
                        });
                    }
                }
            }
            _ => {}
        }
    }
    om
}

/// Collects the non-system messages on the graph's active path that come after the
/// `observed_through_message_id` marker, each paired with its node's timestamp.
///
/// - Nodes without a message, and messages whose role is `MessageRole::System`, are skipped
///   entirely, so a system message can never act as the marker.
/// - With `observed_through_message_id == None`, every remaining message is returned.
/// - With `Some(id)`, messages up to and including the one whose id equals `id` are dropped.
///   If no such message exists on the active path, the result is empty.
pub(crate) fn active_unobserved_message_nodes(
    graph: &SessionGraph,
    observed_through_message_id: Option<&str>,
) -> Vec<MessageNode> {
    let mut unobserved = Vec::new();
    // `Some(marker)` while messages are still being skipped as already observed.
    let mut pending_marker = observed_through_message_id;
    for node in graph.active_path_nodes() {
        let Some(message) = node.message() else {
            continue;
        };
        if matches!(message.role, MessageRole::System) {
            continue;
        }
        if let Some(marker) = pending_marker {
            if message.id.as_str() == marker {
                pending_marker = None;
            }
            continue;
        }
        unobserved.push(MessageNode {
            timestamp: node.timestamp.clone(),
            message,
        });
    }
    unobserved
}

/// Maps each message id to the approximate token count of the messages that follow it.
///
/// The value for a message is the sum of `approx_message_tokens` over every message *after*
/// it in `messages`. The message itself is excluded, and the sum uses saturating addition.
/// If an id appears more than once, the entry computed for its earliest occurrence wins.
pub(crate) fn retained_message_tokens_by_message_id<N: ObservedMessageNode>(
    messages: &[N],
) -> HashMap<&str, usize> {
    let mut tokens_after = 0usize;
    messages
        .iter()
        .rev()
        .map(|node| {
            // Record the suffix total *before* adding this message's own tokens.
            let entry = (node.message().id.as_str(), tokens_after);
            tokens_after = tokens_after.saturating_add(approx_message_tokens(node));
            entry
        })
        .collect()
}

/// Returns how many leading messages must be taken so the remaining tail fits within
/// `tail_budget_tokens`.
///
/// Walking backwards from the end, messages stay in the tail while their cumulative
/// (saturating) token count is `<= tail_budget_tokens`. The result is the index just past the
/// first message, counted from the end, that would push the tail over budget.
///
/// Returns `0` for an empty slice or when everything fits. Returns `messages.len()` when the
/// budget is `0` and the slice is non-empty.
pub(crate) fn prefix_len_leaving_tail_budget<N: ObservedMessageNode>(
    messages: &[N],
    tail_budget_tokens: usize,
) -> usize {
    if messages.is_empty() {
        return 0;
    }
    if tail_budget_tokens == 0 {
        return messages.len();
    }
    let mut tail_tokens = 0usize;
    messages
        .iter()
        .rposition(|message| {
            tail_tokens = tail_tokens.saturating_add(approx_message_tokens(message));
            tail_tokens > tail_budget_tokens
        })
        .map_or(0, |overflow_idx| overflow_idx + 1)
}

/// Returns the length of the shortest prefix of `messages` whose approximate token total
/// (saturating) reaches at least `target_tokens`.
///
/// A target of `0` is always covered by the empty prefix (`Some(0)`). Returns `None` when even
/// the whole slice falls short of the target.
pub(crate) fn prefix_len_covering_tokens<N: ObservedMessageNode>(
    messages: &[N],
    target_tokens: usize,
) -> Option<usize> {
    if target_tokens == 0 {
        return Some(0);
    }
    let mut covered = 0usize;
    messages
        .iter()
        .position(|message| {
            covered = covered.saturating_add(approx_message_tokens(message));
            covered >= target_tokens
        })
        .map(|last_idx| last_idx + 1)
}

/// Greedily splits `messages` into consecutive batches of at most `max_tokens_per_batch`
/// approximate tokens each.
///
/// Order is preserved and every message lands in exactly one batch. Each message counts as at
/// least one token, so zero-token messages still consume budget. A message that alone exceeds
/// the budget is placed in a batch of its own rather than dropped. Empty input yields an empty
/// `Vec`.
///
/// Token totals use saturating addition, like the other token accumulators in this module.
/// Pathological token counts therefore cannot overflow, which would panic in debug builds.
pub(crate) fn split_message_batches<N: ObservedMessageNode + Clone>(
    messages: &[N],
    max_tokens_per_batch: usize,
) -> Vec<Vec<N>> {
    let mut batches = Vec::new();
    let mut current: Vec<N> = Vec::new();
    let mut current_tokens = 0usize;

    for message in messages {
        let tokens = approx_message_tokens(message).max(1);
        // Close the current batch only when it is non-empty, so an oversized message still
        // gets a batch instead of producing an empty one.
        if !current.is_empty() && current_tokens.saturating_add(tokens) > max_tokens_per_batch {
            batches.push(std::mem::take(&mut current));
            current_tokens = 0;
        }
        current.push(message.clone());
        current_tokens = current_tokens.saturating_add(tokens);
    }

    if !current.is_empty() {
        batches.push(current);
    }
    batches
}

pub(crate) fn approx_message_nodes_tokens<N: ObservedMessageNode>(messages: &[N]) -> usize {
    messages.iter().map(approx_message_tokens).sum()
}

/// Approximate token count of a single message.
///
/// The count is measured on the message's `format_message_for_observer` rendering via
/// `approx_token_count`, not on its raw content.
pub(crate) fn approx_message_tokens<N: ObservedMessageNode>(message: &N) -> usize {
    let rendered = format_message_for_observer(message);
    approx_token_count(&rendered)
}

/// Rough token estimate for `text`: one token per four Unicode scalar values, rounded up.
///
/// Characters are counted rather than bytes, so multi-byte UTF-8 text is not inflated by
/// its encoded length.
pub(crate) fn approx_token_count(text: &str) -> usize {
    const CHARS_PER_TOKEN: usize = 4;
    let char_count = text.chars().count();
    char_count.div_ceil(CHARS_PER_TOKEN)
}