//! episodic 0.2.3 — reusable Observational Memory core models and pure transforms.
//!
//! This module defines the v2 observer/reflector prompt-contract types and their builders.
use serde::{Deserialize, Serialize};

use crate::{OmObserverRequest, OmReflectorRequest};

// Single source of truth for the contract literals. `concat!` accepts only
// literals (not `const` items), so these private macros let the marker strings
// below be assembled from exactly the same text as the public constants; a
// version bump therefore cannot desynchronize the markers from the constants.
macro_rules! om_prompt_contract_name_lit {
    () => {
        "episodic.om.prompt"
    };
}
macro_rules! om_prompt_contract_version_lit {
    () => {
        "2.0.0"
    };
}
macro_rules! om_protocol_version_lit {
    () => {
        "om-v2"
    };
}

/// Contract name stamped into every v2 prompt-contract header built by this module.
pub const OM_PROMPT_CONTRACT_NAME: &str = om_prompt_contract_name_lit!();
/// Contract version stamped into every v2 prompt-contract header built by this module.
pub const OM_PROMPT_CONTRACT_VERSION: &str = om_prompt_contract_version_lit!();
/// Protocol version stamped into every v2 prompt-contract header built by this module.
pub const OM_PROTOCOL_VERSION: &str = om_protocol_version_lit!();
/// The three contract markers as XML elements, one element per line (no trailing newline).
pub const OM_PROMPT_CONTRACT_MARKERS_XML_BLOCK: &str = concat!(
    "<contract-name>",
    om_prompt_contract_name_lit!(),
    "</contract-name>\n",
    "<contract-version>",
    om_prompt_contract_version_lit!(),
    "</contract-version>\n",
    "<protocol-version>",
    om_protocol_version_lit!(),
    "</protocol-version>",
);
/// The same three markers on one line, joined as an English list ("A, B, and C").
pub const OM_PROMPT_CONTRACT_MARKERS_XML_INLINE: &str = concat!(
    "<contract-name>",
    om_prompt_contract_name_lit!(),
    "</contract-name>, ",
    "<contract-version>",
    om_prompt_contract_version_lit!(),
    "</contract-version>, and ",
    "<protocol-version>",
    om_protocol_version_lit!(),
    "</protocol-version>",
);

/// Identifies which OM prompt a contract header was built for.
///
/// Variants are (de)serialized in `snake_case`, e.g. `observer_single`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OmPromptRequestKind {
    /// Set by `build_observer_prompt_contract_v2`.
    ObserverSingle,
    /// Set by `build_multi_thread_observer_prompt_contract_v2`.
    ObserverMulti,
    /// Set by `build_reflector_prompt_contract_v2`.
    Reflector,
}

/// Identifying header carried by every v2 OM prompt contract.
///
/// Field order is significant: it is the order serde emits the fields in.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct OmPromptContractHeader {
    /// Contract name; the builders in this module fill it from `OM_PROMPT_CONTRACT_NAME`.
    pub contract_name: String,
    /// Contract version; the builders fill it from `OM_PROMPT_CONTRACT_VERSION`.
    pub contract_version: String,
    /// Protocol version; the builders fill it from `OM_PROTOCOL_VERSION`.
    pub protocol_version: String,
    /// Which prompt (single/multi observer or reflector) this contract belongs to.
    pub request_kind: OmPromptRequestKind,
    /// String form of the originating request's `scope`.
    pub scope: String,
    /// Copy of the originating request's `scope_key`.
    pub scope_key: String,
}

/// Size limits advertised in a v2 prompt contract.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct OmPromptLimitsV2 {
    /// Output token budget; the builders copy the request model's `max_output_tokens`.
    pub max_output_tokens: u32,
    /// Observation character cap; the builders set it only for observer contracts.
    pub observation_max_chars: Option<usize>,
    /// Reflection character cap; the builders set it only for reflector contracts.
    pub reflection_max_chars: Option<usize>,
}

/// Describes the response shape a v2 prompt contract asks for.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct OmPromptOutputContractV2 {
    /// Response format name; the builders in this module always use `"xml"`.
    pub format: String,
    /// Names of the sections the response must contain, in listed order.
    pub required_sections: Vec<String>,
    /// Whether continuation sections are expected; the builders set it to
    /// `!skip_continuation_hints`.
    pub continuation_enabled: bool,
}

/// Full v2 prompt contract for an observer (single- or multi-thread) request.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct OmObserverPromptContractV2 {
    /// Contract identification (name, versions, request kind, scope).
    pub header: OmPromptContractHeader,
    /// Message ids the observer may reference; the builders trim, drop blanks,
    /// sort and de-duplicate them.
    pub known_message_ids: Vec<String>,
    /// Thread the observer should prefer, if any. Defaults to `None` when absent
    /// on input and is omitted from the serialized form when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_thread_id: Option<String>,
    /// True when the request carried non-blank `other_conversations` text.
    pub has_other_conversation_context: bool,
    /// Whether continuation hints were disabled for this prompt.
    pub skip_continuation_hints: bool,
    /// Token and character limits for the response.
    pub limits: OmPromptLimitsV2,
    /// Expected response format and required sections.
    pub output_contract: OmPromptOutputContractV2,
}

/// Full v2 prompt contract for a reflector request.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct OmReflectorPromptContractV2 {
    /// Contract identification (name, versions, request kind, scope).
    pub header: OmPromptContractHeader,
    /// Copied from the reflector request's `generation_count`.
    pub generation_count: u32,
    /// Compression level supplied by the caller of `build_reflector_prompt_contract_v2`;
    /// its scale is not defined in this module.
    pub compression_level: u8,
    /// Whether continuation hints were disabled for this prompt.
    pub skip_continuation_hints: bool,
    /// Token and character limits for the response.
    pub limits: OmPromptLimitsV2,
    /// Expected response format and required sections.
    pub output_contract: OmPromptOutputContractV2,
}

#[must_use]
fn build_observer_prompt_contract(
    request: &OmObserverRequest,
    request_kind: OmPromptRequestKind,
    known_message_ids: &[String],
    skip_continuation_hints: bool,
    preferred_thread_id: Option<&str>,
    observation_max_chars: usize,
) -> OmObserverPromptContractV2 {
    let mut ids = known_message_ids
        .iter()
        .map(|id| id.trim())
        .filter(|id| !id.is_empty())
        .map(ToString::to_string)
        .collect::<Vec<_>>();
    ids.sort();
    ids.dedup();

    OmObserverPromptContractV2 {
        header: OmPromptContractHeader {
            contract_name: OM_PROMPT_CONTRACT_NAME.to_string(),
            contract_version: OM_PROMPT_CONTRACT_VERSION.to_string(),
            protocol_version: OM_PROTOCOL_VERSION.to_string(),
            request_kind,
            scope: request.scope.as_str().to_string(),
            scope_key: request.scope_key.clone(),
        },
        known_message_ids: ids,
        preferred_thread_id: preferred_thread_id.map(ToString::to_string),
        has_other_conversation_context: request
            .other_conversations
            .as_deref()
            .is_some_and(|value| !value.trim().is_empty()),
        skip_continuation_hints,
        limits: OmPromptLimitsV2 {
            max_output_tokens: request.model.max_output_tokens,
            observation_max_chars: Some(observation_max_chars),
            reflection_max_chars: None,
        },
        output_contract: OmPromptOutputContractV2 {
            format: "xml".to_string(),
            required_sections: if skip_continuation_hints {
                vec![
                    "contract-name".to_string(),
                    "contract-version".to_string(),
                    "protocol-version".to_string(),
                    "observations".to_string(),
                ]
            } else {
                vec![
                    "contract-name".to_string(),
                    "contract-version".to_string(),
                    "protocol-version".to_string(),
                    "observations".to_string(),
                    "current-task".to_string(),
                    "suggested-response".to_string(),
                ]
            },
            continuation_enabled: !skip_continuation_hints,
        },
    }
}

#[must_use]
pub fn build_observer_prompt_contract_v2(
    request: &OmObserverRequest,
    known_message_ids: &[String],
    skip_continuation_hints: bool,
    preferred_thread_id: Option<&str>,
    observation_max_chars: usize,
) -> OmObserverPromptContractV2 {
    build_observer_prompt_contract(
        request,
        OmPromptRequestKind::ObserverSingle,
        known_message_ids,
        skip_continuation_hints,
        preferred_thread_id,
        observation_max_chars,
    )
}

#[must_use]
pub fn build_multi_thread_observer_prompt_contract_v2(
    request: &OmObserverRequest,
    known_message_ids: &[String],
    skip_continuation_hints: bool,
    preferred_thread_id: Option<&str>,
    observation_max_chars: usize,
) -> OmObserverPromptContractV2 {
    build_observer_prompt_contract(
        request,
        OmPromptRequestKind::ObserverMulti,
        known_message_ids,
        skip_continuation_hints,
        preferred_thread_id,
        observation_max_chars,
    )
}

#[must_use]
pub fn build_reflector_prompt_contract_v2(
    request: &OmReflectorRequest,
    compression_level: u8,
    skip_continuation_hints: bool,
    reflection_max_chars: usize,
) -> OmReflectorPromptContractV2 {
    OmReflectorPromptContractV2 {
        header: OmPromptContractHeader {
            contract_name: OM_PROMPT_CONTRACT_NAME.to_string(),
            contract_version: OM_PROMPT_CONTRACT_VERSION.to_string(),
            protocol_version: OM_PROTOCOL_VERSION.to_string(),
            request_kind: OmPromptRequestKind::Reflector,
            scope: request.scope.as_str().to_string(),
            scope_key: request.scope_key.clone(),
        },
        generation_count: request.generation_count,
        compression_level,
        skip_continuation_hints,
        limits: OmPromptLimitsV2 {
            max_output_tokens: request.model.max_output_tokens,
            observation_max_chars: None,
            reflection_max_chars: Some(reflection_max_chars),
        },
        output_contract: OmPromptOutputContractV2 {
            format: "xml".to_string(),
            required_sections: if skip_continuation_hints {
                vec![
                    "contract-name".to_string(),
                    "contract-version".to_string(),
                    "protocol-version".to_string(),
                    "observations".to_string(),
                ]
            } else {
                vec![
                    "contract-name".to_string(),
                    "contract-version".to_string(),
                    "protocol-version".to_string(),
                    "observations".to_string(),
                    "current-task".to_string(),
                    "suggested-response".to_string(),
                ]
            },
            continuation_enabled: !skip_continuation_hints,
        },
    }
}