use std::fmt::Debug;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use thiserror::Error;
use uuid::Uuid;
use crate::types::ObservationMemory;
/// Errors surfaced by observation-extraction backends.
#[derive(Debug, Error)]
pub enum ExtractionError {
/// Misconfiguration (e.g. a missing API key) detected before any request is made.
#[error("extractor configuration error: {0}")]
Config(String),
/// Network/HTTP-level failure while talking to the backend.
#[error("extractor transport error: {0}")]
Transport(String),
/// The backend answered, but its payload could not be decoded.
#[error("extractor response parse error: {0}")]
Parse(String),
/// An extraction budget limit was hit (no producer visible in this file).
#[error("extractor budget exceeded: {0}")]
BudgetExceeded(String),
/// Catch-all for failures that fit no other variant (e.g. input-length mismatch).
#[error("extraction failed: {0}")]
Other(String),
}
/// Convenience alias used throughout the extraction API.
pub type ExtractionResult<T> = Result<T, ExtractionError>;
/// One conversation turn handed to an extractor.
#[derive(Debug, Clone)]
pub struct ExtractionMessage {
/// Speaker label (e.g. "user"); an empty string omits the `role:` prefix when rendered.
pub role: String,
/// Raw text of the turn.
pub content: String,
/// When the turn happened, if known; rendered as a `[YYYY-MM-DD]` prefix and
/// used to stamp extracted observations' event_time.
pub event_time: Option<DateTime<Utc>>,
}
/// Backend-agnostic interface for turning conversation turns into
/// `ObservationMemory` records.
#[async_trait]
pub trait ObservationExtractor: Send + Sync + Debug {
    /// Extract observation memories from a single episode's messages.
    async fn extract(
        &self,
        namespace_id: Uuid,
        episode_id: Uuid,
        messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>>;
    /// Extract for several episodes at once. The default implementation simply
    /// runs [`Self::extract`] sequentially; backends with a native batch API
    /// may override it. `episode_ids` and `episodes` must be the same length.
    async fn extract_batch(
        &self,
        namespace_id: Uuid,
        episode_ids: &[Uuid],
        episodes: Vec<&[ExtractionMessage]>,
    ) -> ExtractionResult<Vec<Vec<ObservationMemory>>> {
        let (id_count, episode_count) = (episode_ids.len(), episodes.len());
        if id_count != episode_count {
            return Err(ExtractionError::Other(format!(
                "extract_batch: episode_ids ({}) and episodes ({}) length mismatch",
                id_count, episode_count,
            )));
        }
        let mut results = Vec::with_capacity(episode_count);
        for (episode_id, episode) in episode_ids.iter().copied().zip(episodes) {
            results.push(self.extract(namespace_id, episode_id, episode).await?);
        }
        Ok(results)
    }
}
/// Extractor that never produces observations — a do-nothing default.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopExtractor;
#[async_trait]
impl ObservationExtractor for NoopExtractor {
    /// Always succeeds with zero observations, regardless of input.
    async fn extract(
        &self,
        _namespace_id: Uuid,
        _episode_id: Uuid,
        _messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        Ok(vec![])
    }
}
#[cfg(feature = "observation-extraction")]
mod prompt_v1 {
use super::{ExtractionMessage, ObservationMemory};
use chrono::{DateTime, Utc};
use serde::Deserialize;
use std::fmt::Write as _;
use uuid::Uuid;
/// System prompt (v1) instructing the model to emit a bare JSON array of
/// countable-entity observations. Sent as the API `system` block; the rendered
/// conversation goes in the user message. Runtime string — do not edit lightly:
/// `parse_response` depends on the "JSON array only" output contract.
pub const EXTRACTION_PROMPT_V1: &str = "You are a structured-data extractor. \
Given recalled conversation memories between a user and an assistant, \
extract every **countable entity instance** mentioned by the USER (not the \
assistant's suggestions unless the user confirmed them).
A countable entity is something that could answer a \"how many\", \"how often\", \
or \"list every\" question: items purchased, hours spent on activities, places \
visited, books read, projects worked on, meals cooked, clothing items, pets, \
tanks, plants, games played, etc.
For each instance, output a JSON object:
{
\"entity_type\": \"<category, e.g. 'game_played', 'book_read', 'place_visited'>\",
\"instance\": \"<specific name, e.g. 'Assassin's Creed Odyssey'>\",
\"action\": \"<what the user did, e.g. 'played', 'read', 'visited'>\",
\"quantity\": <numeric value if stated, else null>,
\"unit\": \"<unit if applicable, e.g. 'hours', 'pages', else null>\",
\"confidence\": <0.0-1.0, lower for hedged/hypothetical mentions>
}
Rules:
- Only extract things the USER actually did, owns, or experienced. Exclude \
assistant suggestions that the user did not confirm, hypotheticals, and \
\"I might...\" / \"I'm thinking about...\" statements.
- If the user mentions doing the same thing multiple times with different \
quantities (e.g., \"played 25 hours\" then later \"played another 30 hours\"), \
extract EACH as a separate instance with its own quantity.
- Set confidence < 0.5 for anything hedged, uncertain, merely planned but \
not confirmed, or ambiguous.
- Include items the user needs to pick up, return, buy, etc. — these are \
countable actions even if not yet completed.
- Pay attention to whether something was ACTUALLY done vs merely MENTIONED \
or SUGGESTED. \"I bought boots\" = extract. \"You could try boots\" from the \
assistant without user confirmation = do NOT extract.
- If no countable entities are found, return an empty array: []
Output ONLY a JSON array of objects. No prose, no explanation, no markdown fences.";
/// Render recalled turns as the user-visible message body.
///
/// Each turn becomes one `[YYYY-MM-DD] role: content` line ("unknown" replaces
/// a missing date; the `role: ` prefix is dropped when the role is empty),
/// wrapped between start/end marker lines. Empty input yields a placeholder.
pub(super) fn user_message(messages: &[ExtractionMessage]) -> String {
    if messages.is_empty() {
        return "[No conversation memories provided.]".to_string();
    }
    let mut rendered = String::from("--- Recalled memories ---\n");
    for turn in messages {
        let date = match turn.event_time {
            Some(t) => t.format("%Y-%m-%d").to_string(),
            None => "unknown".to_string(),
        };
        // `writeln!` into a String is infallible; the result is discarded.
        if turn.role.is_empty() {
            let _ = writeln!(rendered, "[{date}] {}", turn.content);
        } else {
            let _ = writeln!(rendered, "[{date}] {}: {}", turn.role, turn.content);
        }
    }
    rendered.push_str("--- End memories ---");
    rendered
}
/// The instruction text sent as the API `system` prompt.
pub(super) fn system_prompt() -> &'static str {
    EXTRACTION_PROMPT_V1
}
/// Concatenate system instructions and rendered memories into one string,
/// separated by a blank line (used where a single prompt string is needed).
pub(super) fn build_prompt(messages: &[ExtractionMessage]) -> String {
    let mut prompt = system_prompt().to_string();
    prompt.push_str("\n\n");
    prompt.push_str(&user_message(messages));
    prompt
}
/// Shape of one observation object as emitted by the model, before it is
/// converted into an `ObservationMemory`.
#[derive(Debug, Deserialize)]
pub(super) struct RawObservation {
/// Category label, e.g. "game_played".
pub(super) entity_type: String,
/// Specific named thing, e.g. "Dune".
pub(super) instance: String,
/// What the user did, e.g. "played".
pub(super) action: String,
/// Numeric amount if the user stated one.
#[serde(default)]
pub(super) quantity: Option<f64>,
/// Unit for `quantity`, e.g. "hours".
#[serde(default)]
pub(super) unit: Option<String>,
/// Model-reported confidence; defaults to 0.8 when the field is absent.
/// Clamped into [0, 1] later by `raw_to_observation`.
#[serde(default = "default_raw_confidence")]
pub(super) confidence: f32,
}
/// Serde default for `RawObservation::confidence`.
pub(super) fn default_raw_confidence() -> f32 {
0.8
}
/// Best-effort parse of a model reply into raw observations.
///
/// Strips any markdown fence, then deserializes the outermost `[` .. `]` span.
/// Anything missing brackets or failing to deserialize yields an empty list
/// rather than an error.
pub(super) fn parse_response(text: &str) -> Vec<RawObservation> {
    let cleaned = strip_markdown_fence(text.trim());
    let Some(open) = cleaned.find('[') else {
        return Vec::new();
    };
    let Some(close) = cleaned.rfind(']') else {
        return Vec::new();
    };
    if close <= open {
        return Vec::new();
    }
    serde_json::from_str(&cleaned[open..=close]).unwrap_or_default()
}
/// Remove a surrounding markdown code fence (``` or ```json) from `s`, if any.
///
/// Input without a fence is returned untouched; an unterminated fence keeps
/// everything after the opener (trimmed). The closing fence is found with
/// `rfind`, so embedded backticks before the final fence are preserved.
pub(super) fn strip_markdown_fence(s: &str) -> &str {
    match s.find("```") {
        None => s,
        Some(open) => {
            let inner = &s[open + 3..];
            // Drop an optional "json" language tag, then leading whitespace.
            let inner = inner.strip_prefix("json").unwrap_or(inner).trim_start();
            match inner.rfind("```") {
                None => inner.trim(),
                Some(close) => inner[..close].trim(),
            }
        }
    }
}
pub(super) fn raw_to_observation(
raw: RawObservation,
namespace_id: Uuid,
episode_id: Uuid,
event_time: Option<DateTime<Utc>>,
) -> ObservationMemory {
let content = format_observation_content(&raw);
let mut obs = ObservationMemory::new(
namespace_id,
episode_id,
raw.entity_type,
raw.instance,
raw.action,
content,
);
obs.quantity = raw.quantity;
obs.unit = raw.unit;
obs.confidence = raw.confidence.clamp(0.0, 1.0);
obs.event_time = event_time;
obs
}
/// Human-readable content line: "<action> <instance>", plus an optional
/// parenthesized "quantity unit" / "quantity" / "unit" suffix.
fn format_observation_content(raw: &RawObservation) -> String {
    let mut content = format!("{} {}", raw.action, raw.instance);
    let suffix = match (raw.quantity, raw.unit.as_deref()) {
        (Some(q), Some(u)) => Some(format!("{q} {u}")),
        (Some(q), None) => Some(format!("{q}")),
        (None, Some(u)) => Some(u.to_string()),
        (None, None) => None,
    };
    if let Some(extra) = suffix {
        content.push_str(&format!(" ({extra})"));
    }
    content
}
}
#[cfg(feature = "observation-extraction")]
pub use prompt_v1::EXTRACTION_PROMPT_V1;
#[cfg(feature = "legacy-anthropic-extractor")]
mod legacy_anthropic {
use super::prompt_v1::{parse_response, raw_to_observation, system_prompt, user_message};
use super::{
ExtractionError, ExtractionMessage, ExtractionResult, ObservationExtractor,
ObservationMemory,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use uuid::Uuid;
/// Model used unless overridden via `with_model`.
const DEFAULT_MODEL: &str = "claude-haiku-4-5-20251001";
/// Completion-token cap per request.
const DEFAULT_MAX_TOKENS: u32 = 4096;
/// HTTP client timeout, in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Value of the `anthropic-version` header sent with every request.
pub(super) const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Direct (one request per episode) extractor backed by the Anthropic
/// Messages API.
#[derive(Debug, Clone)]
pub struct LegacyAnthropicExtractor {
/// Reused HTTP client with `DEFAULT_TIMEOUT_SECS` applied.
client: reqwest::Client,
/// Secret sent via the `x-api-key` header.
api_key: String,
/// Model id for the `model` request field.
model: String,
/// Completion-token cap for each request.
max_tokens: u32,
/// API origin; tests point this at a wiremock server.
base_url: String,
/// When true, the system block carries an ephemeral `cache_control` marker.
prompt_caching_enabled: bool,
}
impl LegacyAnthropicExtractor {
/// Build from the `ANTHROPIC_API_KEY` environment variable.
///
/// # Errors
/// `ExtractionError::Config` when the variable is unset or the HTTP client
/// cannot be built.
pub fn from_env() -> ExtractionResult<Self> {
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| ExtractionError::Config("ANTHROPIC_API_KEY env var not set".into()))?;
Self::new(api_key)
}
/// Build with an explicit API key and all defaults. The key's shape is
/// deliberately not validated here (see the tests module).
///
/// # Errors
/// `ExtractionError::Config` when the HTTP client cannot be constructed.
pub fn new(api_key: impl Into<String>) -> ExtractionResult<Self> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.map_err(|e| ExtractionError::Config(format!("http client build: {e}")))?;
Ok(Self {
client,
api_key: api_key.into(),
model: DEFAULT_MODEL.into(),
max_tokens: DEFAULT_MAX_TOKENS,
base_url: "https://api.anthropic.com".into(),
prompt_caching_enabled: true,
})
}
/// Builder: override the model id.
#[must_use]
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
/// Builder: override the API origin (used by tests to target a mock server).
#[must_use]
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into();
self
}
/// Builder: disable the `cache_control` marker on the system block.
#[must_use]
pub fn without_prompt_caching(mut self) -> Self {
self.prompt_caching_enabled = false;
self
}
// Crate-internal accessors so the batched extractor can reuse this
// extractor's client, credentials, and settings.
pub(super) fn client(&self) -> &reqwest::Client {
&self.client
}
pub(super) fn api_key(&self) -> &str {
&self.api_key
}
pub(super) fn model(&self) -> &str {
&self.model
}
pub(super) fn max_tokens(&self) -> u32 {
self.max_tokens
}
pub(super) fn base_url(&self) -> &str {
&self.base_url
}
pub(super) fn prompt_caching_enabled(&self) -> bool {
self.prompt_caching_enabled
}
/// Test-only helper exposing the combined system+user prompt string.
#[cfg(test)]
pub(super) fn build_prompt(messages: &[ExtractionMessage]) -> String {
super::prompt_v1::build_prompt(messages)
}
}
/// Minimal view of a Messages API response: only the content blocks are read.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
content: Vec<AnthropicContentBlock>,
}
/// One response content block; only blocks with `type == "text"` are consumed.
#[derive(Debug, Deserialize)]
struct AnthropicContentBlock {
#[serde(rename = "type")]
block_type: String,
#[serde(default)]
text: String,
}
/// Request body for `POST /v1/messages`.
#[derive(Debug, Serialize)]
struct AnthropicRequest<'a> {
model: &'a str,
max_tokens: u32,
temperature: f32,
// Omitted entirely (not null) when None, matching the API's optional field.
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<Vec<SystemBlock<'a>>>,
messages: Vec<AnthropicMessage<'a>>,
}
/// A single chat turn in the request payload.
#[derive(Debug, Serialize)]
pub(super) struct AnthropicMessage<'a> {
pub(super) role: &'a str,
pub(super) content: &'a str,
}
/// Structured system-prompt block; a present `cache_control` opts the block
/// into Anthropic prompt caching.
#[derive(Debug, Serialize)]
pub(super) struct SystemBlock<'a> {
#[serde(rename = "type")]
pub(super) block_type: &'static str,
pub(super) text: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) cache_control: Option<CacheControl>,
}
/// Cache directive attached to a system block, e.g. `{"type": "ephemeral"}`.
#[derive(Debug, Serialize)]
pub(super) struct CacheControl {
#[serde(rename = "type")]
pub(super) cache_type: &'static str,
}
#[async_trait]
impl ObservationExtractor for LegacyAnthropicExtractor {
    /// Extract observations for one episode via a single Messages API call.
    ///
    /// The instruction prompt travels in the `system` blocks; the rendered
    /// conversation is the sole user message. Non-2xx responses and network
    /// failures map to `Transport`; undecodable bodies map to `Parse`.
    async fn extract(
        &self,
        namespace_id: Uuid,
        episode_id: Uuid,
        messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        let rendered = user_message(messages);
        // The newest turn timestamp becomes every observation's event_time.
        let newest_event_time = messages.iter().filter_map(|m| m.event_time).max();
        let cache_control = self
            .prompt_caching_enabled
            .then(|| CacheControl {
                cache_type: "ephemeral",
            });
        let request = AnthropicRequest {
            model: &self.model,
            max_tokens: self.max_tokens,
            // Deterministic decoding for structured extraction.
            temperature: 0.0,
            system: Some(vec![SystemBlock {
                block_type: "text",
                text: system_prompt(),
                cache_control,
            }]),
            messages: vec![AnthropicMessage {
                role: "user",
                content: &rendered,
            }],
        };
        let endpoint = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
        let response = self
            .client
            .post(&endpoint)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", ANTHROPIC_VERSION)
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| ExtractionError::Transport(e.to_string()))?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(ExtractionError::Transport(format!("HTTP {status}: {body}")));
        }
        let parsed: AnthropicResponse = response
            .json()
            .await
            .map_err(|e| ExtractionError::Parse(e.to_string()))?;
        // Use the first text block; any other block types are ignored.
        let text = parsed
            .content
            .into_iter()
            .find(|b| b.block_type == "text")
            .map(|b| b.text)
            .unwrap_or_default();
        Ok(parse_response(&text)
            .into_iter()
            .map(|raw| raw_to_observation(raw, namespace_id, episode_id, newest_event_time))
            .collect())
    }
}
/// Wiremock-backed tests for the direct extractor: request shape (headers,
/// system blocks, caching), response parsing, and error mapping.
#[cfg(test)]
#[allow(
clippy::bind_instead_of_map,
reason = "test code: `.and_then(|e| Ok(e))` is intentional in `new_rejects_when_api_key_lookup_fails` — it documents the constructor's contract that key shape is not validated"
)]
mod tests {
use super::super::prompt_v1::{EXTRACTION_PROMPT_V1, RawObservation};
use super::*;
use chrono::{DateTime, Utc};
use wiremock::matchers::{body_partial_json, header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Minimal but schema-complete Messages API response envelope around `text`.
fn anthropic_response_body(text: &str) -> serde_json::Value {
serde_json::json!({
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": "claude-haiku-4-5-20251001",
"content": [{"type": "text", "text": text}],
"stop_reason": "end_turn",
"usage": {"input_tokens": 0, "output_tokens": 0},
})
}
// Happy path: two JSON observations map into ObservationMemory with ids,
// quantity, and unit preserved.
#[tokio::test]
async fn extractor_parses_successful_response() {
let server = MockServer::start().await;
let canned = serde_json::to_string(&serde_json::json!([
{
"entity_type": "game_played",
"instance": "Assassin's Creed Odyssey",
"action": "played",
"quantity": 70,
"unit": "hours",
"confidence": 0.9
},
{
"entity_type": "book_read",
"instance": "Dune",
"action": "read",
"quantity": null,
"unit": null,
"confidence": 0.8
}
]))
.unwrap();
Mock::given(method("POST"))
.and(path("/v1/messages"))
.and(header("x-api-key", "test-key"))
.and(header("anthropic-version", ANTHROPIC_VERSION))
.respond_with(
ResponseTemplate::new(200).set_body_json(anthropic_response_body(&canned)),
)
.mount(&server)
.await;
let extractor = LegacyAnthropicExtractor::new("test-key")
.unwrap()
.with_base_url(server.uri());
let ns = Uuid::new_v4();
let ep = Uuid::new_v4();
let result = extractor
.extract(
ns,
ep,
&[ExtractionMessage {
role: "user".into(),
content: "I played AC Odyssey for 70 hours".into(),
event_time: None,
}],
)
.await
.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].namespace_id, ns);
assert_eq!(result[0].episode_id, ep);
assert_eq!(result[0].instance, "Assassin's Creed Odyssey");
assert_eq!(result[0].quantity, Some(70.0));
assert_eq!(result[0].unit.as_deref(), Some("hours"));
assert_eq!(result[1].instance, "Dune");
assert!(result[1].quantity.is_none());
}
// A ```json fenced reply still parses (strip_markdown_fence path).
#[tokio::test]
async fn extractor_survives_markdown_fence_wrapper() {
let server = MockServer::start().await;
let fenced = "```json\n[{\"entity_type\":\"x\",\"instance\":\"y\",\"action\":\"z\",\"confidence\":0.8}]\n```";
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(
ResponseTemplate::new(200).set_body_json(anthropic_response_body(fenced)),
)
.mount(&server)
.await;
let extractor = LegacyAnthropicExtractor::new("k")
.unwrap()
.with_base_url(server.uri());
let out = extractor
.extract(Uuid::new_v4(), Uuid::new_v4(), &[])
.await
.unwrap();
assert_eq!(out.len(), 1);
assert_eq!(out[0].instance, "y");
}
// Prose replies (no JSON array) degrade to an empty result, not an error.
#[tokio::test]
async fn extractor_returns_empty_on_unparseable_response() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(
ResponseTemplate::new(200)
.set_body_json(anthropic_response_body("sorry, I cannot help with that")),
)
.mount(&server)
.await;
let extractor = LegacyAnthropicExtractor::new("k")
.unwrap()
.with_base_url(server.uri());
let out = extractor
.extract(Uuid::new_v4(), Uuid::new_v4(), &[])
.await
.unwrap();
assert!(out.is_empty());
}
// Non-2xx statuses become ExtractionError::Transport.
#[tokio::test]
async fn extractor_surfaces_http_errors_as_transport_error() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(ResponseTemplate::new(500).set_body_string("server broke"))
.mount(&server)
.await;
let extractor = LegacyAnthropicExtractor::new("k")
.unwrap()
.with_base_url(server.uri());
let err = extractor
.extract(Uuid::new_v4(), Uuid::new_v4(), &[])
.await
.unwrap_err();
assert!(matches!(err, ExtractionError::Transport(_)));
}
// NOTE(review): the name says "rejects" but the assertion verifies the
// OPPOSITE — `new("")` succeeds because the constructor does not validate
// key shape. A clearer name would be `new_accepts_any_api_key_shape`;
// renaming also requires updating the mod-level #[allow] reason string.
#[test]
fn new_rejects_when_api_key_lookup_fails() {
let err = LegacyAnthropicExtractor::new("")
.and_then(|e| {
Ok(e)
})
.err();
assert!(err.is_none(), "constructor should not validate key shape");
}
// NOTE(review): tautological — it constructs the very variant it then
// matches and never calls `from_env`; env-var mutation was presumably
// avoided because tests run in parallel. Consider a serialized env test.
#[test]
fn from_env_error_is_config_variant() {
let e = ExtractionError::Config("missing".into());
assert!(matches!(e, ExtractionError::Config(_)));
}
// build_prompt = instructions + framed memories, in one string.
#[test]
fn prompt_contains_instruction_and_memory_body() {
let msgs = [ExtractionMessage {
role: "user".into(),
content: "I played AC Odyssey".into(),
event_time: None,
}];
let prompt = LegacyAnthropicExtractor::build_prompt(&msgs);
assert!(prompt.contains("countable entity"));
assert!(prompt.contains("user: I played AC Odyssey"));
assert!(prompt.contains("--- Recalled memories ---"));
}
#[test]
fn prompt_handles_empty_messages() {
let prompt = LegacyAnthropicExtractor::build_prompt(&[]);
assert!(prompt.contains("No conversation memories provided"));
}
// Empty role drops the "role: " prefix; colons inside content (URLs, times)
// must not be mistaken for role separators.
#[test]
fn prompt_omits_role_prefix_when_role_empty() {
let msgs = [ExtractionMessage {
role: String::new(),
content: "Check http://example.com at 10:30".to_string(),
event_time: None,
}];
let prompt = LegacyAnthropicExtractor::build_prompt(&msgs);
assert!(prompt.contains("[unknown] Check http://example.com at 10:30"));
assert!(!prompt.contains("10: 30"));
assert!(!prompt.contains("http: //"));
}
// Out-of-range model confidence (1.5) is clamped into [0, 1].
#[test]
fn parse_response_clamps_confidence() {
let raw = r#"[{"entity_type":"x","instance":"y","action":"z","confidence":1.5}]"#;
let parsed = parse_response(raw);
let obs = raw_to_observation(
parsed.into_iter().next().unwrap(),
Uuid::new_v4(),
Uuid::new_v4(),
None,
);
assert!(obs.confidence <= 1.0);
assert!(obs.confidence >= 0.0);
}
// event_time lives in metadata only; content never carries a date prefix.
#[test]
fn content_excludes_date_prefix_event_time_in_metadata_only() {
let raw = RawObservation {
entity_type: "degree_earned".into(),
instance: "Business Administration".into(),
action: "graduated with".into(),
quantity: Some(1.0),
unit: None,
confidence: 0.9,
};
let event_time = Some(
DateTime::parse_from_rfc3339("2024-05-10T14:00:00Z")
.unwrap()
.with_timezone(&Utc),
);
let obs = raw_to_observation(raw, Uuid::new_v4(), Uuid::new_v4(), event_time);
assert_eq!(obs.content, "graduated with Business Administration (1)");
assert_eq!(obs.event_time, event_time);
}
#[test]
fn content_omits_date_when_event_time_absent() {
let raw = RawObservation {
entity_type: "task_tried".into(),
instance: "Todoist".into(),
action: "will try out".into(),
quantity: None,
unit: None,
confidence: 0.8,
};
let obs = raw_to_observation(raw, Uuid::new_v4(), Uuid::new_v4(), None);
assert_eq!(obs.content, "will try out Todoist");
assert!(obs.event_time.is_none());
}
// Default request shape: system block carries ephemeral cache_control, and
// the user message holds only the framed memories (no instruction leakage).
#[tokio::test]
async fn extractor_sends_cached_system_block_by_default() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.and(body_partial_json(serde_json::json!({
"system": [{
"type": "text",
"text": EXTRACTION_PROMPT_V1,
"cache_control": {"type": "ephemeral"},
}],
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(anthropic_response_body("[]")),
)
.mount(&server)
.await;
let extractor = LegacyAnthropicExtractor::new("test-key")
.unwrap()
.with_base_url(server.uri());
extractor
.extract(
Uuid::new_v4(),
Uuid::new_v4(),
&[ExtractionMessage {
role: "user".into(),
content: "I played AC Odyssey".into(),
event_time: None,
}],
)
.await
.unwrap();
let received = server.received_requests().await.expect("requests recorded");
assert_eq!(received.len(), 1);
let body: serde_json::Value = received[0].body_json().expect("json body");
let user_content = body["messages"][0]["content"]
.as_str()
.expect("user message content");
assert!(
user_content.contains("--- Recalled memories ---"),
"user message must carry the recalled-memories framing, got: {user_content}"
);
assert!(
!user_content.contains("structured-data extractor"),
"user message leaked instruction text: {user_content}"
);
}
// without_prompt_caching removes cache_control entirely (not null).
#[tokio::test]
async fn extractor_omits_cache_control_when_caching_disabled() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(
ResponseTemplate::new(200).set_body_json(anthropic_response_body("[]")),
)
.mount(&server)
.await;
let extractor = LegacyAnthropicExtractor::new("test-key")
.unwrap()
.with_base_url(server.uri())
.without_prompt_caching();
extractor
.extract(Uuid::new_v4(), Uuid::new_v4(), &[])
.await
.unwrap();
let received = server.received_requests().await.expect("requests recorded");
assert_eq!(received.len(), 1);
let body: serde_json::Value = received[0].body_json().expect("json body");
assert_eq!(body["system"][0]["type"], "text");
assert_eq!(body["system"][0]["text"], EXTRACTION_PROMPT_V1);
assert!(
body["system"][0].get("cache_control").is_none(),
"cache_control must be omitted when caching disabled, got: {}",
body["system"][0]
);
}
// The user message contains only the supplied turns plus framing.
#[tokio::test]
async fn extractor_user_message_excludes_instruction_prompt() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(
ResponseTemplate::new(200).set_body_json(anthropic_response_body("[]")),
)
.mount(&server)
.await;
let extractor = LegacyAnthropicExtractor::new("test-key")
.unwrap()
.with_base_url(server.uri());
extractor
.extract(
Uuid::new_v4(),
Uuid::new_v4(),
&[ExtractionMessage {
role: "user".into(),
content: "I cooked dinner three times".into(),
event_time: None,
}],
)
.await
.unwrap();
let received = server.received_requests().await.expect("requests recorded");
let body: serde_json::Value = received[0].body_json().expect("json body");
let user_content = body["messages"][0]["content"]
.as_str()
.expect("user message content");
assert!(
!user_content.contains("structured"),
"user message must not embed the system prompt header: {user_content}"
);
assert!(
user_content.contains("--- Recalled memories ---"),
"user message must contain the memory framing: {user_content}"
);
assert!(
user_content.contains("I cooked dinner three times"),
"user message must contain the supplied turn: {user_content}"
);
}
}
}
#[cfg(feature = "legacy-anthropic-extractor")]
pub use legacy_anthropic::LegacyAnthropicExtractor;
#[cfg(feature = "legacy-anthropic-extractor")]
mod legacy_batched_anthropic {
use super::legacy_anthropic::{
ANTHROPIC_VERSION, AnthropicMessage, CacheControl, LegacyAnthropicExtractor, SystemBlock,
};
use super::prompt_v1::{parse_response, raw_to_observation, system_prompt, user_message};
use super::{
ExtractionError, ExtractionMessage, ExtractionResult, ObservationExtractor,
ObservationMemory,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use uuid::Uuid;
/// Delay between consecutive batch-status polls.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(30);
/// Ceiling on total polling time before a batch is abandoned (2 hours).
const DEFAULT_MAX_WAIT: Duration = Duration::from_secs(2 * 60 * 60);
/// Batched extractor built on the Anthropic message-batches API:
/// submit all episodes at once, poll until the batch ends, then collect
/// per-episode results. Single-episode calls delegate to `inner`.
#[derive(Debug, Clone)]
pub struct LegacyBatchedAnthropicExtractor {
/// Underlying extractor supplying client, credentials, model, and settings.
inner: LegacyAnthropicExtractor,
/// How often batch status is polled.
poll_interval: Duration,
/// Maximum total time to wait for the batch to end.
max_wait: Duration,
}
impl LegacyBatchedAnthropicExtractor {
/// Wrap an existing extractor with default polling settings.
#[must_use]
pub fn new(inner: LegacyAnthropicExtractor) -> Self {
Self {
inner,
poll_interval: DEFAULT_POLL_INTERVAL,
max_wait: DEFAULT_MAX_WAIT,
}
}
/// Build from `ANTHROPIC_API_KEY`, mirroring `LegacyAnthropicExtractor::from_env`.
pub fn from_env() -> ExtractionResult<Self> {
Ok(Self::new(LegacyAnthropicExtractor::from_env()?))
}
/// Builder: override the status-poll interval.
#[must_use]
pub fn with_poll_interval(mut self, d: Duration) -> Self {
self.poll_interval = d;
self
}
/// Builder: override the total wait budget for a batch.
#[must_use]
pub fn with_max_wait(mut self, d: Duration) -> Self {
self.max_wait = d;
self
}
/// POST one entry per episode to `/v1/messages/batches` and return the
/// server-assigned batch id. Each episode's UUID doubles as the entry's
/// `custom_id` so results can be matched back to inputs.
async fn submit_batch(
&self,
episode_ids: &[Uuid],
episodes: &[&[ExtractionMessage]],
) -> ExtractionResult<String> {
let user_messages: Vec<String> = episodes.iter().map(|ep| user_message(ep)).collect();
let cache_control = if self.inner.prompt_caching_enabled() {
Some(CacheControl {
cache_type: "ephemeral",
})
} else {
None
};
let entries: Vec<BatchEntry<'_>> = episode_ids
.iter()
.zip(user_messages.iter())
.map(|(eid, content)| BatchEntry {
custom_id: eid.to_string(),
params: BatchParams {
model: self.inner.model(),
max_tokens: self.inner.max_tokens(),
// Deterministic decoding, matching the single-shot path.
temperature: 0.0,
system: Some(vec![SystemBlock {
block_type: "text",
text: system_prompt(),
// CacheControl derives no Clone, so a fresh value is built
// per entry from the Option computed above.
cache_control: cache_control.as_ref().map(|_| CacheControl {
cache_type: "ephemeral",
}),
}]),
messages: vec![AnthropicMessage {
role: "user",
content,
}],
},
})
.collect();
let req = BatchSubmitRequest { requests: entries };
let url = format!(
"{}/v1/messages/batches",
self.inner.base_url().trim_end_matches('/')
);
let response = self
.inner
.client()
.post(&url)
.header("x-api-key", self.inner.api_key())
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
.json(&req)
.send()
.await
.map_err(|e| ExtractionError::Transport(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(ExtractionError::Transport(format!("HTTP {status}: {body}")));
}
let parsed: BatchSubmitResponse = response
.json()
.await
.map_err(|e| ExtractionError::Parse(e.to_string()))?;
Ok(parsed.id)
}
/// Poll batch status every `poll_interval` until it reports "ended".
///
/// # Errors
/// `Transport` on HTTP failure or a terminal failure status
/// (canceling/canceled/expired/failed); `Other` once `max_wait` elapses.
async fn await_completion(&self, batch_id: &str) -> ExtractionResult<()> {
let start = Instant::now();
let url = format!(
"{}/v1/messages/batches/{batch_id}",
self.inner.base_url().trim_end_matches('/')
);
loop {
let response = self
.inner
.client()
.get(&url)
.header("x-api-key", self.inner.api_key())
.header("anthropic-version", ANTHROPIC_VERSION)
.send()
.await
.map_err(|e| ExtractionError::Transport(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(ExtractionError::Transport(format!("HTTP {status}: {body}")));
}
let status_body: BatchStatusResponse = response
.json()
.await
.map_err(|e| ExtractionError::Parse(e.to_string()))?;
match status_body.processing_status.as_str() {
"ended" => return Ok(()),
"canceling" | "canceled" | "expired" | "failed" => {
return Err(ExtractionError::Transport(format!(
"batch {batch_id} terminated with status {}",
status_body.processing_status
)));
}
// Any other status (e.g. in_progress) keeps polling.
_ => {}
}
if Instant::now().duration_since(start) >= self.max_wait {
return Err(ExtractionError::Other(format!(
"batch {batch_id} exceeded max_wait of {:?}",
self.max_wait
)));
}
tokio::time::sleep(self.poll_interval).await;
}
}
/// Fetch the JSONL results stream and regroup observations per input
/// episode, preserving input order. Malformed lines, unknown custom_ids,
/// and per-entry failures degrade to empty observation lists (with a
/// warning) rather than failing the whole batch.
///
/// NOTE(review): unlike the single-shot path, batch results carry no
/// event_time (`None` is passed to `raw_to_observation`) because the
/// original messages are not available here — confirm this is intended.
async fn collect_results(
&self,
batch_id: &str,
namespace_id: Uuid,
episode_ids: &[Uuid],
) -> ExtractionResult<Vec<Vec<ObservationMemory>>> {
let url = format!(
"{}/v1/messages/batches/{batch_id}/results",
self.inner.base_url().trim_end_matches('/')
);
let response = self
.inner
.client()
.get(&url)
.header("x-api-key", self.inner.api_key())
.header("anthropic-version", ANTHROPIC_VERSION)
.send()
.await
.map_err(|e| ExtractionError::Transport(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(ExtractionError::Transport(format!("HTTP {status}: {body}")));
}
let body = response
.text()
.await
.map_err(|e| ExtractionError::Transport(e.to_string()))?;
let mut by_custom_id: HashMap<String, Vec<ObservationMemory>> =
HashMap::with_capacity(episode_ids.len());
for line in body.lines() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let parsed: BatchResultLine = match serde_json::from_str(trimmed) {
Ok(p) => p,
Err(e) => {
tracing::warn!(
target: "pensyve::observation",
error = %e,
line = %trimmed,
"skipping malformed batch result line",
);
continue;
}
};
let Some(eid) = parse_episode_id(&parsed.custom_id, episode_ids) else {
tracing::warn!(
target: "pensyve::observation",
custom_id = %parsed.custom_id,
"batch result custom_id not in input set — dropping",
);
continue;
};
match parsed.result {
BatchResultPayload::Succeeded { message } => {
// Same text-block selection and lenient parse as the
// single-shot extractor.
let text = message
.content
.into_iter()
.find(|b| b.block_type == "text")
.map(|b| b.text)
.unwrap_or_default();
let raws = parse_response(&text);
let observations = raws
.into_iter()
.map(|r| raw_to_observation(r, namespace_id, eid, None))
.collect();
by_custom_id.insert(parsed.custom_id, observations);
}
BatchResultPayload::Errored { error }
| BatchResultPayload::Canceled { error }
| BatchResultPayload::Expired { error } => {
tracing::warn!(
target: "pensyve::observation",
custom_id = %parsed.custom_id,
error = ?error,
"batch entry failed — emitting empty observations for this episode",
);
by_custom_id.insert(parsed.custom_id, Vec::new());
}
}
}
// Re-emit in input order; episodes with no result line get an empty list.
let out = episode_ids
.iter()
.map(|eid| {
by_custom_id.remove(&eid.to_string()).unwrap_or_else(|| {
tracing::warn!(
target: "pensyve::observation",
episode_id = %eid,
"no batch result for episode — emitting empty observations",
);
Vec::new()
})
})
.collect();
Ok(out)
}
}
/// Parse a batch `custom_id` back into a UUID, accepting it only when it is
/// one of the submitted episode ids; otherwise `None`.
fn parse_episode_id(custom_id: &str, episode_ids: &[Uuid]) -> Option<Uuid> {
    let candidate = Uuid::parse_str(custom_id).ok()?;
    episode_ids.contains(&candidate).then_some(candidate)
}
/// Body for `POST /v1/messages/batches`: one entry per episode.
#[derive(Debug, Serialize)]
struct BatchSubmitRequest<'a> {
requests: Vec<BatchEntry<'a>>,
}
/// One batch entry: caller-chosen id plus the usual message parameters.
#[derive(Debug, Serialize)]
struct BatchEntry<'a> {
/// Episode UUID rendered as a string; used to match results to inputs.
custom_id: String,
params: BatchParams<'a>,
}
/// Per-entry request parameters (mirrors the non-batched request shape).
#[derive(Debug, Serialize)]
struct BatchParams<'a> {
model: &'a str,
max_tokens: u32,
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<Vec<SystemBlock<'a>>>,
messages: Vec<AnthropicMessage<'a>>,
}
/// Batch-submission response; only the batch id is consumed.
#[derive(Debug, Deserialize)]
struct BatchSubmitResponse {
id: String,
}
/// Status-poll response; only `processing_status` is inspected.
#[derive(Debug, Deserialize)]
struct BatchStatusResponse {
processing_status: String,
}
/// One line of the JSONL results stream.
#[derive(Debug, Deserialize)]
struct BatchResultLine {
custom_id: String,
result: BatchResultPayload,
}
/// Per-entry outcome, discriminated by the JSON `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum BatchResultPayload {
Succeeded {
message: BatchResultMessage,
},
Errored {
#[serde(default)]
error: serde_json::Value,
},
Canceled {
#[serde(default)]
error: serde_json::Value,
},
Expired {
#[serde(default)]
error: serde_json::Value,
},
}
/// A successful entry's message; only its content blocks are consumed.
#[derive(Debug, Deserialize)]
struct BatchResultMessage {
#[serde(default)]
content: Vec<BatchResultContentBlock>,
}
/// Content block within a batch result; only `type == "text"` blocks are used.
#[derive(Debug, Deserialize)]
struct BatchResultContentBlock {
#[serde(rename = "type")]
block_type: String,
#[serde(default)]
text: String,
}
#[async_trait]
impl ObservationExtractor for LegacyBatchedAnthropicExtractor {
    /// One-off extraction bypasses the batch API and delegates to the inner
    /// direct extractor.
    async fn extract(
        &self,
        namespace_id: Uuid,
        episode_id: Uuid,
        messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        self.inner.extract(namespace_id, episode_id, messages).await
    }
    /// Full batch flow: validate input lengths, submit, poll until the batch
    /// ends, then collect per-episode results in input order.
    async fn extract_batch(
        &self,
        namespace_id: Uuid,
        episode_ids: &[Uuid],
        episodes: Vec<&[ExtractionMessage]>,
    ) -> ExtractionResult<Vec<Vec<ObservationMemory>>> {
        let (id_count, episode_count) = (episode_ids.len(), episodes.len());
        if id_count != episode_count {
            return Err(ExtractionError::Other(format!(
                "extract_batch: length mismatch ({} ids vs {} episodes)",
                id_count, episode_count,
            )));
        }
        // Nothing to do — avoid a round-trip that would submit an empty batch.
        if episode_count == 0 {
            return Ok(Vec::new());
        }
        let batch_id = self.submit_batch(episode_ids, &episodes).await?;
        self.await_completion(&batch_id).await?;
        self.collect_results(&batch_id, namespace_id, episode_ids)
            .await
    }
}
#[cfg(test)]
#[allow(
clippy::too_many_lines,
reason = "test code: each wiremock scenario sets up its own fixture inline for readability"
)]
mod tests {
use super::*;
use wiremock::matchers::{method, path, path_regex};
use wiremock::{Mock, MockServer, ResponseTemplate};
fn batch_submit_response(id: &str) -> serde_json::Value {
serde_json::json!({
"id": id,
"type": "message_batch",
"processing_status": "in_progress",
"request_counts": {"processing": 0, "succeeded": 0, "errored": 0, "canceled": 0, "expired": 0},
})
}
fn status_response(processing_status: &str) -> serde_json::Value {
serde_json::json!({
"processing_status": processing_status,
})
}
fn jsonl_succeeded(custom_id: &str, text: &str) -> String {
let line = serde_json::json!({
"custom_id": custom_id,
"result": {
"type": "succeeded",
"message": {
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": "claude-haiku-4-5-20251001",
"content": [{"type": "text", "text": text}],
"stop_reason": "end_turn",
"usage": {"input_tokens": 0, "output_tokens": 0},
},
},
});
line.to_string()
}
fn jsonl_errored(custom_id: &str) -> String {
let line = serde_json::json!({
"custom_id": custom_id,
"result": {
"type": "errored",
"error": {"type": "overloaded_error", "message": "slow down"},
},
});
line.to_string()
}
fn make_extractor(server_uri: &str) -> LegacyBatchedAnthropicExtractor {
let inner = LegacyAnthropicExtractor::new("test-key")
.unwrap()
.with_base_url(server_uri.to_string());
LegacyBatchedAnthropicExtractor::new(inner)
.with_poll_interval(Duration::from_millis(10))
.with_max_wait(Duration::from_secs(5))
}
fn msg(text: &str) -> ExtractionMessage {
ExtractionMessage {
role: "user".into(),
content: text.into(),
event_time: None,
}
}
#[tokio::test]
async fn batch_submit_posts_one_entry_per_episode() {
    // Runs submit -> poll -> (empty) results end to end, then inspects the
    // captured submit POST: one batch entry per episode, keyed by episode id.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/messages/batches"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_json(batch_submit_response("msgbatch_test123")),
        )
        .expect(1)
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/msgbatch_[a-zA-Z0-9_]+$"))
        .respond_with(ResponseTemplate::new(200).set_body_json(status_response("ended")))
        .mount(&server)
        .await;
    // Empty results body: this test only cares about the submitted request.
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/.+/results$"))
        .respond_with(ResponseTemplate::new(200).set_body_string(""))
        .mount(&server)
        .await;
    let extractor = make_extractor(&server.uri());
    let ns = Uuid::new_v4();
    let ids = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
    let ep0 = [msg("ep0 turn 0"), msg("ep0 turn 1")];
    let ep1 = [msg("ep1 turn 0"), msg("ep1 turn 1")];
    let ep2 = [msg("ep2 turn 0"), msg("ep2 turn 1")];
    let episodes: Vec<&[ExtractionMessage]> = vec![&ep0, &ep1, &ep2];
    extractor
        .extract_batch(ns, &ids, episodes)
        .await
        .expect("extract_batch ok");
    // Pull the recorded submit request back out of the mock server.
    let received = server.received_requests().await.expect("requests recorded");
    let submit = received
        .iter()
        .find(|r| {
            r.method == wiremock::http::Method::POST
                && r.url.path() == "/v1/messages/batches"
        })
        .expect("submit POST captured");
    let body: serde_json::Value = submit.body_json().expect("submit body json");
    let requests = body["requests"].as_array().expect("requests array");
    assert_eq!(requests.len(), 3);
    for (i, entry) in requests.iter().enumerate() {
        // custom_id carries the episode UUID so results can be routed back.
        assert_eq!(entry["custom_id"].as_str().unwrap(), ids[i].to_string());
        let params = &entry["params"];
        assert_eq!(
            params["model"].as_str().unwrap(),
            "claude-haiku-4-5-20251001"
        );
        assert_eq!(params["temperature"], serde_json::json!(0.0));
        let system = params["system"].as_array().expect("system array");
        assert_eq!(system[0]["type"].as_str().unwrap(), "text");
        assert_eq!(
            system[0]["cache_control"]["type"].as_str().unwrap(),
            "ephemeral",
            "cache_control must be ephemeral by default",
        );
        let user_content = params["messages"][0]["content"].as_str().unwrap();
        assert!(
            user_content.contains("--- Recalled memories ---"),
            "user message must carry the recalled-memories framing, got: {user_content}"
        );
    }
}
#[tokio::test]
async fn batch_polls_until_ended() {
    // Exercises the status-poll retry loop: the first GET answers
    // "in_progress", subsequent ones answer "ended".
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/messages/batches"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_json(batch_submit_response("msgbatch_polltest")),
        )
        .mount(&server)
        .await;
    // Served at most once, after which the identical matcher below takes over.
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/msgbatch_[a-zA-Z0-9_]+$"))
        .respond_with(
            ResponseTemplate::new(200).set_body_json(status_response("in_progress")),
        )
        .up_to_n_times(1)
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/msgbatch_[a-zA-Z0-9_]+$"))
        .respond_with(ResponseTemplate::new(200).set_body_json(status_response("ended")))
        .mount(&server)
        .await;
    let ids = vec![Uuid::new_v4(), Uuid::new_v4()];
    let body = format!(
        "{}\n{}\n",
        jsonl_succeeded(&ids[0].to_string(), "[]"),
        jsonl_succeeded(&ids[1].to_string(), "[]"),
    );
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/.+/results$"))
        .respond_with(ResponseTemplate::new(200).set_body_string(body))
        .mount(&server)
        .await;
    let extractor = make_extractor(&server.uri());
    let ns = Uuid::new_v4();
    let ep0 = [msg("ep0")];
    let ep1 = [msg("ep1")];
    let episodes: Vec<&[ExtractionMessage]> = vec![&ep0, &ep1];
    let out = extractor
        .extract_batch(ns, &ids, episodes)
        .await
        .expect("extract_batch should poll through in_progress to ended");
    assert_eq!(out.len(), 2);
}
#[tokio::test]
async fn batch_collects_results_routed_by_custom_id() {
    // The results JSONL arrives in a different order than the submitted
    // episodes; outputs must still line up with the inputs via custom_id.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/messages/batches"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_json(batch_submit_response("msgbatch_routetest")),
        )
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/msgbatch_[a-zA-Z0-9_]+$"))
        .respond_with(ResponseTemplate::new(200).set_body_json(status_response("ended")))
        .mount(&server)
        .await;
    let ids = vec![Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
    let payload_2 = serde_json::to_string(&serde_json::json!([{
        "entity_type": "game_played", "instance": "Tetris", "action": "played", "confidence": 0.9
    }])).unwrap();
    let payload_0 = serde_json::to_string(&serde_json::json!([{
        "entity_type": "book_read", "instance": "Dune", "action": "read", "confidence": 0.9
    }]))
    .unwrap();
    // Deliberately shuffled: entry for ids[2] first, ids[0] last.
    let body = format!(
        "{}\n{}\n{}\n",
        jsonl_succeeded(&ids[2].to_string(), &payload_2),
        jsonl_succeeded(&ids[1].to_string(), "[]"),
        jsonl_succeeded(&ids[0].to_string(), &payload_0),
    );
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/.+/results$"))
        .respond_with(ResponseTemplate::new(200).set_body_string(body))
        .mount(&server)
        .await;
    let extractor = make_extractor(&server.uri());
    let ns = Uuid::new_v4();
    let ep0 = [msg("ep0")];
    let ep1 = [msg("ep1")];
    let ep2 = [msg("ep2")];
    let episodes: Vec<&[ExtractionMessage]> = vec![&ep0, &ep1, &ep2];
    let out = extractor
        .extract_batch(ns, &ids, episodes)
        .await
        .expect("extract_batch ok");
    // Despite the shuffled JSONL, out[i] corresponds to ids[i].
    assert_eq!(out.len(), 3);
    assert_eq!(out[0].len(), 1);
    assert_eq!(out[0][0].instance, "Dune");
    assert_eq!(out[0][0].episode_id, ids[0]);
    assert!(out[1].is_empty());
    assert_eq!(out[2].len(), 1);
    assert_eq!(out[2][0].instance, "Tetris");
    assert_eq!(out[2][0].episode_id, ids[2]);
}
#[tokio::test]
async fn batch_per_entry_error_emits_empty_observations() {
    // One entry errors at the API level, the other succeeds; the outer call
    // must still return Ok, with an empty Vec for the errored episode.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/messages/batches"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_json(batch_submit_response("msgbatch_errortest")),
        )
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/msgbatch_[a-zA-Z0-9_]+$"))
        .respond_with(ResponseTemplate::new(200).set_body_json(status_response("ended")))
        .mount(&server)
        .await;
    let ids = vec![Uuid::new_v4(), Uuid::new_v4()];
    let payload = serde_json::to_string(&serde_json::json!([{
        "entity_type": "game_played",
        "instance": "Solitaire",
        "action": "played",
        "confidence": 0.9,
    }]))
    .unwrap();
    let body = format!(
        "{}\n{}\n",
        jsonl_errored(&ids[0].to_string()),
        jsonl_succeeded(&ids[1].to_string(), &payload),
    );
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/.+/results$"))
        .respond_with(ResponseTemplate::new(200).set_body_string(body))
        .mount(&server)
        .await;
    let extractor = make_extractor(&server.uri());
    let ns = Uuid::new_v4();
    let ep0 = [msg("ep0")];
    let ep1 = [msg("ep1")];
    let episodes: Vec<&[ExtractionMessage]> = vec![&ep0, &ep1];
    let out = extractor
        .extract_batch(ns, &ids, episodes)
        .await
        .expect("per-entry errors must not fail the outer call");
    assert_eq!(out.len(), 2);
    assert!(
        out[0].is_empty(),
        "errored entry must yield empty observations"
    );
    assert_eq!(out[1].len(), 1);
    assert_eq!(out[1][0].instance, "Solitaire");
}
#[tokio::test]
async fn batch_rejects_length_mismatch() {
    // Two ids but three episodes: the guard must fire before any HTTP call,
    // so no mocks are mounted on the server at all.
    let server = MockServer::start().await;
    let extractor = make_extractor(&server.uri());
    let namespace = Uuid::new_v4();
    let ids = [Uuid::new_v4(), Uuid::new_v4()];
    let first = [msg("a")];
    let second = [msg("b")];
    let third = [msg("c")];
    let episodes: Vec<&[ExtractionMessage]> = vec![&first, &second, &third];
    let err = extractor
        .extract_batch(namespace, &ids, episodes)
        .await
        .expect_err("length mismatch must error");
    match err {
        ExtractionError::Other(message) => assert!(message.contains("length mismatch")),
        unexpected => panic!("expected Other, got {unexpected:?}"),
    }
}
#[tokio::test]
async fn batch_returns_empty_for_zero_episodes() {
    // The submit endpoint is mounted with expect(0): any POST fails the
    // test, proving the zero-episode short-circuit never hits the network.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/messages/batches"))
        .respond_with(ResponseTemplate::new(500))
        .expect(0)
        .mount(&server)
        .await;
    let extractor = make_extractor(&server.uri());
    let results = extractor
        .extract_batch(Uuid::new_v4(), &[], Vec::new())
        .await
        .expect("zero-episode call must succeed");
    assert!(results.is_empty());
}
#[tokio::test]
async fn prewarm_then_cached_extract_makes_exactly_one_batch_post() {
    // Two-phase flow: (1) prewarm a wave of episodes through the batch API,
    // (2) serve per-episode extracts from a fingerprint-keyed cache. The
    // whole wave must cost exactly one batch POST.
    use crate::observation::NoopExtractor;
    use crate::observation::cached_bulk::{CachedBulkExtractor, fingerprint_messages};
    use std::collections::HashMap;
    use std::sync::Arc;
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/messages/batches"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_json(batch_submit_response("msgbatch_phaseCtwo")),
        )
        .expect(1)
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/msgbatch_[a-zA-Z0-9_]+$"))
        .respond_with(ResponseTemplate::new(200).set_body_json(status_response("ended")))
        .mount(&server)
        .await;
    let ns = Uuid::new_v4();
    let ids = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
    let ep0 = [msg("user: I played AC Odyssey")];
    let ep1 = [msg("user: I read Dune")];
    let ep2 = [msg("user: I cooked tacos")];
    let episodes: Vec<&[ExtractionMessage]> = vec![&ep0, &ep1, &ep2];
    let result_body = format!(
        "{}\n{}\n{}",
        jsonl_succeeded(
            &ids[0].to_string(),
            r#"[{"entity_type":"game_played","instance":"AC Odyssey","action":"played","quantity":null,"unit":null,"confidence":0.95}]"#,
        ),
        jsonl_succeeded(
            &ids[1].to_string(),
            r#"[{"entity_type":"book_read","instance":"Dune","action":"read","quantity":null,"unit":null,"confidence":0.9}]"#,
        ),
        jsonl_succeeded(
            &ids[2].to_string(),
            r#"[{"entity_type":"meal_cooked","instance":"tacos","action":"cooked","quantity":null,"unit":null,"confidence":0.85}]"#,
        ),
    );
    Mock::given(method("GET"))
        .and(path_regex(r"^/v1/messages/batches/.+/results$"))
        .respond_with(ResponseTemplate::new(200).set_body_string(result_body))
        .mount(&server)
        .await;
    // Phase 1: prewarm via the batch API.
    let batch_extractor = make_extractor(&server.uri());
    let batched_results = batch_extractor
        .extract_batch(ns, &ids, episodes.clone())
        .await
        .expect("prewarm extract_batch ok");
    assert_eq!(batched_results.len(), 3);
    // Build the fingerprint -> observations cache from the prewarmed wave.
    let mut cache: HashMap<u64, Vec<ObservationMemory>> = HashMap::new();
    for (msgs_slice, observations) in episodes.iter().zip(batched_results.into_iter()) {
        let fp = fingerprint_messages(msgs_slice);
        cache.insert(fp, observations);
    }
    // Phase 2: per-episode extracts served from cache, rebound to live ids;
    // NoopExtractor as fallback guarantees no further network on a miss.
    let cached = CachedBulkExtractor::new(cache, Arc::new(NoopExtractor));
    for msgs_slice in &episodes {
        let live_ep = Uuid::new_v4();
        let live_ns = Uuid::new_v4();
        let out = cached
            .extract(live_ns, live_ep, msgs_slice)
            .await
            .expect("cache hit ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].namespace_id, live_ns);
        assert_eq!(out[0].episode_id, live_ep);
    }
    let received = server.received_requests().await.expect("requests recorded");
    let post_count = received
        .iter()
        .filter(|r| {
            r.method == wiremock::http::Method::POST
                && r.url.path() == "/v1/messages/batches"
        })
        .count();
    assert_eq!(
        post_count, 1,
        "must POST exactly once for the entire wave (got {post_count}); cache served the per-episode extracts",
    );
}
}
}
#[cfg(feature = "legacy-anthropic-extractor")]
pub use legacy_batched_anthropic::LegacyBatchedAnthropicExtractor;
#[cfg(feature = "observation-extraction")]
mod cached_bulk {
use super::{ExtractionMessage, ExtractionResult, ObservationExtractor, ObservationMemory};
use async_trait::async_trait;
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use uuid::Uuid;
/// Hashes a message slice into a stable `u64` cache key.
///
/// The slice length, each role/content pair, and each optional event time
/// feed the hash; a presence discriminant byte keeps `None` from colliding
/// with a real timestamp. Any textual or temporal difference between two
/// episodes therefore yields (with overwhelming probability) different keys.
#[must_use]
pub fn fingerprint_messages(messages: &[ExtractionMessage]) -> u64 {
    let mut hasher = DefaultHasher::new();
    messages.len().hash(&mut hasher);
    for message in messages {
        message.role.hash(&mut hasher);
        message.content.hash(&mut hasher);
        if let Some(t) = message.event_time {
            1_u8.hash(&mut hasher);
            // Out-of-range timestamps hash as 0; acceptable for a cache key.
            t.timestamp_nanos_opt().unwrap_or(0).hash(&mut hasher);
        } else {
            0_u8.hash(&mut hasher);
        }
    }
    hasher.finish()
}
/// Serves extractions from a prewarmed fingerprint -> observations map,
/// delegating to `fallback` on cache misses.
#[derive(Debug, Clone)]
pub struct CachedBulkExtractor {
    // Prewarmed results keyed by `fingerprint_messages` of the episode.
    cache: Arc<HashMap<u64, Vec<ObservationMemory>>>,
    // Inner extractor consulted when no cache entry exists.
    fallback: Arc<dyn ObservationExtractor>,
}
impl CachedBulkExtractor {
    /// Wraps a prewarmed cache and a fallback extractor.
    #[must_use]
    pub fn new(
        cache: HashMap<u64, Vec<ObservationMemory>>,
        fallback: Arc<dyn ObservationExtractor>,
    ) -> Self {
        let cache = Arc::new(cache);
        Self { cache, fallback }
    }
    /// Number of prewarmed entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.cache.len()
    }
    /// `true` when no entries were prewarmed.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// `true` when `fingerprint` has a prewarmed entry.
    #[must_use]
    pub fn contains(&self, fingerprint: u64) -> bool {
        self.cache.get(&fingerprint).is_some()
    }
}
#[async_trait]
impl ObservationExtractor for CachedBulkExtractor {
    /// Serves from the prewarmed cache when the message fingerprint matches,
    /// rebinding namespace/episode ids to the live call; otherwise logs the
    /// miss and falls through to the inner extractor.
    async fn extract(
        &self,
        namespace_id: Uuid,
        episode_id: Uuid,
        messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        let fp = fingerprint_messages(messages);
        match self.cache.get(&fp) {
            Some(cached) => {
                // The prewarm pass ran under throwaway ids; rebind every
                // observation to the ids of this live call.
                let mut rebound = Vec::with_capacity(cached.len());
                for obs in cached {
                    let mut copy = obs.clone();
                    copy.namespace_id = namespace_id;
                    copy.episode_id = episode_id;
                    rebound.push(copy);
                }
                Ok(rebound)
            }
            None => {
                tracing::warn!(
                    target: "pensyve::observation",
                    episode_id = %episode_id,
                    fingerprint = fp,
                    "CachedBulkExtractor cache miss — falling through to inner extractor",
                );
                self.fallback
                    .extract(namespace_id, episode_id, messages)
                    .await
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::observation::{ExtractionError, NoopExtractor};
use chrono::{TimeZone, Utc};
use std::sync::Mutex;
/// Single user message with a fixed event time so fingerprints are stable
/// across the whole test module.
fn make_msgs(content: &str) -> Vec<ExtractionMessage> {
    let fixed = Utc.with_ymd_and_hms(2026, 4, 27, 12, 0, 0).unwrap();
    vec![ExtractionMessage {
        role: String::from("user"),
        content: content.to_string(),
        event_time: Some(fixed),
    }]
}
/// Observation fixture: a `game_played` memory whose instance string doubles
/// as its content.
fn make_obs(ns: Uuid, ep: Uuid, instance: &str) -> ObservationMemory {
    ObservationMemory::new(ns, ep, "game_played", instance, "played", instance)
}
#[tokio::test]
async fn cache_hit_serves_from_prewarmed_payload_and_rebinds_ids() {
    // Prewarm under throwaway ids, then extract under live ids: the cached
    // payload must come back rebound, and the fallback must stay untouched.
    let prewarm_ns = Uuid::new_v4();
    let prewarm_ep = Uuid::new_v4();
    let live_ns = Uuid::new_v4();
    let live_ep = Uuid::new_v4();
    let msgs = make_msgs("I played AC Odyssey for 70 hours");
    let fp = fingerprint_messages(&msgs);
    let mut cache = HashMap::new();
    cache.insert(fp, vec![make_obs(prewarm_ns, prewarm_ep, "AC Odyssey")]);
    let fallback = Arc::new(TrackingFallback::default());
    let extractor = CachedBulkExtractor::new(cache, fallback.clone());
    let out = extractor
        .extract(live_ns, live_ep, &msgs)
        .await
        .expect("cache hit returns ok");
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].instance, "AC Odyssey");
    // Ids must be the live ones, not the prewarm ones.
    assert_eq!(out[0].namespace_id, live_ns);
    assert_eq!(out[0].episode_id, live_ep);
    assert_eq!(
        fallback.calls(),
        0,
        "fallback must NOT fire on a cache hit (otherwise the bulk discount is wasted)",
    );
}
#[tokio::test]
async fn cache_miss_falls_through_to_inner_extractor() {
    // Empty cache guarantees a miss; the tracking fallback counts the call.
    let cache: HashMap<u64, Vec<ObservationMemory>> = HashMap::new();
    let fallback = Arc::new(TrackingFallback::default());
    let extractor = CachedBulkExtractor::new(cache, fallback.clone());
    let msgs = make_msgs("never seen by the prewarm pass");
    let ns = Uuid::new_v4();
    let ep = Uuid::new_v4();
    let out = extractor.extract(ns, ep, &msgs).await.expect("ok");
    assert!(out.is_empty(), "TrackingFallback returns empty");
    assert_eq!(
        fallback.calls(),
        1,
        "fallback must fire exactly once on a miss"
    );
}
#[tokio::test]
async fn fingerprint_collisions_not_observed_for_distinct_content() {
    // Different content must hash differently (a sanity check, not a proof).
    let left = fingerprint_messages(&make_msgs("user: I played AC Odyssey"));
    let right = fingerprint_messages(&make_msgs("user: I played Dune"));
    assert_ne!(left, right);
}
#[tokio::test]
async fn fingerprint_stable_across_calls() {
    // Same input, same key: required for the prewarm cache to ever hit.
    let msgs = make_msgs("hello");
    assert_eq!(fingerprint_messages(&msgs), fingerprint_messages(&msgs));
}
#[tokio::test]
async fn empty_cache_is_diagnostic_only_not_an_error() {
    // An empty cache is a valid (if useless) configuration; the diagnostic
    // accessors must report it consistently.
    let extractor = CachedBulkExtractor::new(HashMap::new(), Arc::new(NoopExtractor));
    assert_eq!(extractor.len(), 0);
    assert!(extractor.is_empty());
    assert!(!extractor.contains(0));
}
/// Test double that counts how many times `extract` is invoked.
#[derive(Debug, Default)]
struct TrackingFallback {
    // Mutex (not an atomic) is fine here: tests drive it from one task.
    calls: Mutex<usize>,
}
impl TrackingFallback {
    /// Number of `extract` calls observed so far.
    fn calls(&self) -> usize {
        *self.calls.lock().unwrap()
    }
}
#[async_trait]
impl ObservationExtractor for TrackingFallback {
    /// Records the call and returns no observations.
    async fn extract(
        &self,
        _namespace_id: Uuid,
        _episode_id: Uuid,
        _messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        *self.calls.lock().unwrap() += 1;
        Ok(Vec::new())
    }
}
// Compile-time check: CachedBulkExtractor must stay usable behind
// `dyn ObservationExtractor` (object safety); never called at runtime.
#[allow(dead_code)]
fn cached_bulk_is_object_safe() {
    fn takes_dyn(_: &dyn ObservationExtractor) {}
    let cb = CachedBulkExtractor::new(HashMap::new(), Arc::new(NoopExtractor));
    takes_dyn(&cb);
}
// Keeps the `ExtractionError` import exercised even if tests churn.
#[allow(dead_code)]
fn _error_in_scope() -> Option<ExtractionError> {
    None
}
}
}
#[cfg(feature = "observation-extraction")]
pub use cached_bulk::{CachedBulkExtractor, fingerprint_messages};
#[cfg(feature = "observation-extraction")]
mod localllm {
use super::prompt_v1::{self, RawObservation, parse_response, raw_to_observation};
use super::{
ExtractionError, ExtractionMessage, ExtractionResult, ObservationExtractor,
ObservationMemory,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use uuid::Uuid;
const DEFAULT_BASE_URL: &str = "http://localhost:8888/v1";
const DEFAULT_MODEL: &str = "qwen3.6-35b-a3b";
const DEFAULT_MAX_TOKENS: u32 = 4096;
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Extractor that talks to an OpenAI-compatible `/v1/chat/completions`
/// endpoint (e.g. a locally hosted inference server).
#[derive(Debug, Clone)]
pub struct LocalLLMExtractor {
    client: reqwest::Client,
    // Endpoint base; a missing `/v1` suffix is appended at request time.
    base_url: String,
    model: String,
    // Sent as a bearer token when present.
    api_key: Option<String>,
    // Completion token budget forwarded as `max_tokens`.
    max_tokens: u32,
}
impl LocalLLMExtractor {
    /// Builds an extractor for the given endpoint and model.
    ///
    /// # Errors
    /// Returns [`ExtractionError::Config`] when the HTTP client cannot be
    /// built.
    pub fn new(
        base_url: impl Into<String>,
        model: impl Into<String>,
        api_key: Option<String>,
    ) -> ExtractionResult<Self> {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
            .build()
            .map_err(|e| ExtractionError::Config(format!("http client build: {e}")))?;
        Ok(Self {
            client,
            base_url: base_url.into(),
            model: model.into(),
            api_key,
            max_tokens: DEFAULT_MAX_TOKENS,
        })
    }
    /// Reads `PENSYVE_EXTRACTOR_{URL,MODEL,API_KEY}` from the environment,
    /// falling back to the module defaults for URL and model.
    ///
    /// # Errors
    /// Propagates [`ExtractionError::Config`] from [`Self::new`].
    pub fn from_env() -> ExtractionResult<Self> {
        let base_url =
            std::env::var("PENSYVE_EXTRACTOR_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.into());
        let model =
            std::env::var("PENSYVE_EXTRACTOR_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.into());
        let api_key = std::env::var("PENSYVE_EXTRACTOR_API_KEY").ok();
        Self::new(base_url, model, api_key)
    }
    /// Replaces the base URL (builder style).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
    /// Replaces the model name (builder style).
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
    /// Replaces the completion token budget (builder style).
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }
    /// Renders the v1 extraction prompt for one episode's messages.
    fn build_prompt(messages: &[ExtractionMessage]) -> String {
        prompt_v1::build_prompt(messages)
    }
}
/// Request body for an OpenAI-compatible `/v1/chat/completions` call.
#[derive(Debug, Serialize)]
struct OpenAIRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAIMessage<'a>>,
    max_tokens: u32,
    temperature: f32,
    // Non-standard field; presumably consumed by the local server's chat
    // template to suppress "thinking" output — verify against the server.
    chat_template_kwargs: ChatTemplateKwargs,
}
/// Chat-template options forwarded verbatim to the server.
#[derive(Debug, Serialize)]
struct ChatTemplateKwargs {
    enable_thinking: bool,
}
/// One chat message in the request payload.
#[derive(Debug, Serialize)]
struct OpenAIMessage<'a> {
    role: &'a str,
    content: &'a str,
}
/// Response envelope; only `choices` is read, all other fields are ignored.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    #[serde(default)]
    choices: Vec<OpenAIChoice>,
}
/// A single completion choice; defaults keep partial payloads parseable.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    #[serde(default)]
    message: OpenAIChoiceMessage,
}
/// Assistant message; missing `content` decodes as the empty string.
#[derive(Debug, Deserialize, Default)]
struct OpenAIChoiceMessage {
    #[serde(default)]
    content: String,
}
#[async_trait]
impl ObservationExtractor for LocalLLMExtractor {
    /// Sends one chat-completion request for the episode and parses the
    /// reply into observations.
    ///
    /// HTTP-level failures surface as [`ExtractionError::Transport`]; a
    /// malformed JSON envelope as [`ExtractionError::Parse`]. Unparseable
    /// *model output* is not an error — it simply yields no observations.
    async fn extract(
        &self,
        namespace_id: Uuid,
        episode_id: Uuid,
        messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        let user_prompt = Self::build_prompt(messages);
        // The newest turn timestamp anchors every extracted observation.
        let newest_event = messages.iter().filter_map(|m| m.event_time).max();
        let request = OpenAIRequest {
            model: &self.model,
            messages: vec![OpenAIMessage {
                role: "user",
                content: &user_prompt,
            }],
            max_tokens: self.max_tokens,
            temperature: 0.0,
            chat_template_kwargs: ChatTemplateKwargs {
                enable_thinking: false,
            },
        };
        // Normalize the base URL so both "http://host" and "http://host/v1"
        // end up targeting ".../v1/chat/completions".
        let trimmed = self.base_url.trim_end_matches('/');
        let normalized = if trimmed.ends_with("/v1") {
            trimmed.to_string()
        } else {
            format!("{trimmed}/v1")
        };
        let url = format!("{normalized}/chat/completions");
        let mut post = self.client.post(&url).json(&request);
        if let Some(key) = self.api_key.as_deref() {
            post = post.bearer_auth(key);
        }
        let response = post
            .send()
            .await
            .map_err(|e| ExtractionError::Transport(e.to_string()))?;
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(ExtractionError::Transport(format!("HTTP {status}: {body}")));
        }
        let parsed: OpenAIResponse = response
            .json()
            .await
            .map_err(|e| ExtractionError::Parse(e.to_string()))?;
        // Only the first choice matters; an empty choice list reads as "".
        let text = match parsed.choices.into_iter().next() {
            Some(choice) => choice.message.content,
            None => String::new(),
        };
        let raws: Vec<RawObservation> = parse_response(&text);
        let mut observations = Vec::with_capacity(raws.len());
        for raw in raws {
            observations.push(raw_to_observation(raw, namespace_id, episode_id, newest_event));
        }
        Ok(observations)
    }
}
#[cfg(test)]
#[allow(
clippy::err_expect,
reason = "test code: `.err().expect()` mirrors the structure of preceding ok-path asserts"
)]
mod tests {
use super::*;
use chrono::{DateTime, Utc};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Minimal OpenAI-shaped chat-completion body whose assistant text is `text`.
fn openai_response_body(text: &str) -> serde_json::Value {
    let choice = serde_json::json!({
        "index": 0,
        "message": {"role": "assistant", "content": text},
        "finish_reason": "stop",
    });
    serde_json::json!({
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "model": "local",
        "choices": [choice],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    })
}
#[test]
fn from_env_uses_defaults_when_unset() {
    // Only non-emptiness is asserted, so the test also passes when the
    // PENSYVE_EXTRACTOR_* variables happen to be set in the environment.
    let extractor = LocalLLMExtractor::from_env().expect("from_env");
    assert!(!extractor.base_url.is_empty());
    assert!(!extractor.model.is_empty());
}
#[test]
fn build_prompt_date_anchors_turn_bodies() {
    // Each turn should be prefixed with its [YYYY-MM-DD] event date, and the
    // prompt must carry the recalled-memories section header.
    let msgs = [ExtractionMessage {
        role: "user".into(),
        content: "I picked up boots from Zara.".into(),
        event_time: DateTime::parse_from_rfc3339("2024-02-05T10:00:00Z")
            .ok()
            .map(|d| d.with_timezone(&Utc)),
    }];
    let p = LocalLLMExtractor::build_prompt(&msgs);
    assert!(p.contains("[2024-02-05] user: I picked up boots from Zara."));
    assert!(p.contains("--- Recalled memories ---"));
}
#[tokio::test]
async fn extractor_parses_openai_shaped_response() {
    // Happy path: the assistant text is a JSON array of raw observations,
    // which must be converted and stamped with the episode's event time.
    let server = MockServer::start().await;
    let raw_json = r#"[{"entity_type":"degree_earned","instance":"Business Administration","action":"graduated with","quantity":1,"confidence":0.9}]"#;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(
            ResponseTemplate::new(200).set_body_json(openai_response_body(raw_json)),
        )
        .expect(1)
        .mount(&server)
        .await;
    let extractor = LocalLLMExtractor::new(server.uri(), "local", None).unwrap();
    let event_time = DateTime::parse_from_rfc3339("2024-05-10T14:00:00Z")
        .ok()
        .map(|d| d.with_timezone(&Utc));
    let msgs = [ExtractionMessage {
        role: String::new(),
        content: "I graduated with a BS in BA.".into(),
        event_time,
    }];
    let out = extractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &msgs)
        .await
        .expect("ok");
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].content, "graduated with Business Administration (1)");
    // Observation inherits the newest message event time.
    assert_eq!(out[0].event_time, event_time);
}
#[tokio::test]
async fn extractor_surfaces_http_errors_as_transport_error() {
    // Non-2xx responses must map to ExtractionError::Transport, not Parse.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(500).set_body_string("boom"))
        .expect(1)
        .mount(&server)
        .await;
    let extractor = LocalLLMExtractor::new(server.uri(), "local", None).unwrap();
    let err = extractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &[])
        .await
        .err()
        .expect("err");
    match err {
        ExtractionError::Transport(_) => {}
        other => panic!("expected Transport, got {other:?}"),
    }
}
#[tokio::test]
async fn extractor_returns_empty_on_unparseable_response() {
    // A model refusal (non-JSON assistant text) is not an error: the
    // extractor just yields zero observations.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_json(openai_response_body("I'm sorry, I cannot comply.")),
        )
        .mount(&server)
        .await;
    let extractor = LocalLLMExtractor::new(server.uri(), "local", None).unwrap();
    let out = extractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &[])
        .await
        .expect("ok");
    assert!(out.is_empty());
}
#[tokio::test]
async fn base_url_without_v1_suffix_is_normalized() {
    // The mock only answers under /v1/...; expect(1) therefore proves the
    // extractor appended the missing /v1 segment to a bare base URL.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(openai_response_body("[]")))
        .expect(1)
        .mount(&server)
        .await;
    let bare_uri = server.uri();
    let extractor = LocalLLMExtractor::new(bare_uri, "local", None).unwrap();
    extractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &[])
        .await
        .expect("ok");
}
#[test]
fn default_config_matches_spec_table() {
    // Guards against silent drift of the documented defaults.
    assert_eq!(DEFAULT_BASE_URL, "http://localhost:8888/v1");
    assert_eq!(DEFAULT_MODEL, "qwen3.6-35b-a3b");
    assert_eq!(DEFAULT_MAX_TOKENS, 4096);
}
#[test]
fn builders_chain_and_override_defaults() {
    // Each with_* consumes self; the chain must compose and the last write
    // must win over the constructor arguments.
    let extractor = LocalLLMExtractor::new("http://example.com/v1", "default-model", None)
        .expect("new")
        .with_base_url("http://override.test/v1")
        .with_model("qwen3.6-35b-a3b")
        .with_max_tokens(2048);
    assert_eq!(extractor.base_url, "http://override.test/v1");
    assert_eq!(extractor.model, "qwen3.6-35b-a3b");
    assert_eq!(extractor.max_tokens, 2048);
    assert!(extractor.api_key.is_none());
}
#[tokio::test]
async fn request_body_matches_openai_chat_completions_shape() {
    // body_partial_json checks only the listed fields, so extra request
    // fields are tolerated; expect(1) fails the test if the shape differs.
    let server = MockServer::start().await;
    let expected_body = serde_json::json!({
        "model": "qwen3.6-35b-a3b",
        "temperature": 0.0,
        "max_tokens": 4096,
        "chat_template_kwargs": {"enable_thinking": false},
    });
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .and(wiremock::matchers::body_partial_json(expected_body))
        .respond_with(ResponseTemplate::new(200).set_body_json(openai_response_body("[]")))
        .expect(1)
        .mount(&server)
        .await;
    let extractor = LocalLLMExtractor::new(server.uri(), "qwen3.6-35b-a3b", None).unwrap();
    let msgs = [ExtractionMessage {
        role: String::new(),
        content: "I bought 2 books today.".into(),
        event_time: None,
    }];
    extractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &msgs)
        .await
        .expect("ok");
}
#[tokio::test]
async fn request_user_message_carries_extraction_prompt_v1() {
    // Captures the raw request body via the responder closure, then checks
    // the single user message embeds the v1 prompt framing plus the turn text.
    let server = MockServer::start().await;
    let captured: std::sync::Arc<std::sync::Mutex<Option<serde_json::Value>>> =
        std::sync::Arc::new(std::sync::Mutex::new(None));
    let cap = captured.clone();
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(move |req: &wiremock::Request| {
            // Stash the decoded body; ignore decode/lock failures — the
            // final expect() below will catch a missing capture.
            if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&req.body)
                && let Ok(mut g) = cap.lock()
            {
                *g = Some(v);
            }
            ResponseTemplate::new(200).set_body_json(openai_response_body("[]"))
        })
        .expect(1)
        .mount(&server)
        .await;
    let extractor = LocalLLMExtractor::new(server.uri(), "qwen3.6-35b-a3b", None).unwrap();
    let msgs = [ExtractionMessage {
        role: String::new(),
        content: "I bought 2 books today.".into(),
        event_time: None,
    }];
    extractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &msgs)
        .await
        .expect("ok");
    let body = captured
        .lock()
        .ok()
        .and_then(|g| g.clone())
        .expect("captured body");
    let content = body["messages"][0]["content"]
        .as_str()
        .expect("user message content");
    assert!(content.contains("structured-data extractor"));
    assert!(content.contains("--- Recalled memories ---"));
    assert!(content.contains("I bought 2 books today."));
    assert_eq!(body["messages"][0]["role"].as_str(), Some("user"));
}
#[tokio::test]
async fn extractor_with_short_timeout_surfaces_transport_error() {
    // Server answers after 500ms but the client times out at 50ms; the
    // reqwest timeout must surface as ExtractionError::Transport.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_json(openai_response_body("[]"))
                .set_delay(Duration::from_millis(500)),
        )
        .mount(&server)
        .await;
    // Built via struct literal (not new()) to inject the short-timeout client.
    let client = reqwest::Client::builder()
        .timeout(Duration::from_millis(50))
        .build()
        .expect("client");
    let extractor = LocalLLMExtractor {
        client,
        base_url: server.uri(),
        model: "qwen3.6-35b-a3b".into(),
        api_key: None,
        max_tokens: DEFAULT_MAX_TOKENS,
    };
    let err = extractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &[])
        .await
        .err()
        .expect("err");
    match err {
        ExtractionError::Transport(_) => {}
        other => panic!("expected Transport, got {other:?}"),
    }
}
}
}
#[cfg(feature = "observation-extraction")]
pub use localllm::LocalLLMExtractor;
#[cfg(feature = "observation-extraction")]
mod batched_localllm {
use super::localllm::LocalLLMExtractor;
use super::{
ExtractionError, ExtractionMessage, ExtractionResult, ObservationExtractor,
ObservationMemory,
};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Semaphore;
use uuid::Uuid;
/// Concurrency-limited fan-out wrapper around [`LocalLLMExtractor`].
#[derive(Debug, Clone)]
pub struct BatchedLocalLLMExtractor {
    inner: LocalLLMExtractor,
    // Upper bound on simultaneous in-flight extract calls (always >= 1).
    max_concurrency: usize,
}
impl BatchedLocalLLMExtractor {
    /// Default ceiling on simultaneous in-flight requests.
    pub const DEFAULT_MAX_CONCURRENCY: usize = 4;
    /// Wraps `inner` with the default concurrency limit.
    #[must_use]
    pub fn new(inner: LocalLLMExtractor) -> Self {
        let max_concurrency = Self::DEFAULT_MAX_CONCURRENCY;
        Self {
            inner,
            max_concurrency,
        }
    }
    /// Overrides the concurrency limit; values below one are clamped to one.
    #[must_use]
    pub fn with_max_concurrency(mut self, n: usize) -> Self {
        self.max_concurrency = if n == 0 { 1 } else { n };
        self
    }
    /// Borrows the wrapped single-episode extractor.
    #[must_use]
    pub fn inner(&self) -> &LocalLLMExtractor {
        &self.inner
    }
    /// Current concurrency ceiling.
    #[must_use]
    pub fn max_concurrency(&self) -> usize {
        self.max_concurrency
    }
}
#[async_trait]
impl ObservationExtractor for BatchedLocalLLMExtractor {
    /// Single-episode extraction simply delegates to the wrapped extractor.
    async fn extract(
        &self,
        namespace_id: Uuid,
        episode_id: Uuid,
        messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        self.inner.extract(namespace_id, episode_id, messages).await
    }
    /// Fans episodes out concurrently, bounded by `max_concurrency`.
    ///
    /// Results come back in input order; the first per-episode error fails
    /// the whole batch.
    async fn extract_batch(
        &self,
        namespace_id: Uuid,
        episode_ids: &[Uuid],
        episodes: Vec<&[ExtractionMessage]>,
    ) -> ExtractionResult<Vec<Vec<ObservationMemory>>> {
        if episode_ids.len() != episodes.len() {
            return Err(ExtractionError::Other(format!(
                "extract_batch: episode_ids ({}) and episodes ({}) length mismatch",
                episode_ids.len(),
                episodes.len(),
            )));
        }
        if episodes.is_empty() {
            return Ok(Vec::new());
        }
        // A semaphore (rather than chunking) bounds in-flight requests so a
        // slow episode never stalls unrelated ones.
        let gate = Arc::new(Semaphore::new(self.max_concurrency));
        let inner = &self.inner;
        let tasks = episode_ids.iter().copied().zip(episodes).map(|(eid, msgs)| {
            let gate = Arc::clone(&gate);
            async move {
                let _permit = gate.acquire().await.map_err(|e| {
                    ExtractionError::Other(format!("semaphore unexpectedly closed: {e}"))
                })?;
                inner.extract(namespace_id, eid, msgs).await
            }
        });
        // join_all preserves input order; collecting into Result
        // short-circuits on the first error.
        futures::future::join_all(tasks)
            .await
            .into_iter()
            .collect()
    }
}
#[cfg(test)]
#[allow(
clippy::err_expect,
reason = "test code: `.err().expect()` mirrors the structure of preceding ok-path asserts"
)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Minimal OpenAI-shaped chat-completion body whose assistant text is `text`.
fn openai_response_body(text: &str) -> serde_json::Value {
    let choice = serde_json::json!({
        "index": 0,
        "message": {"role": "assistant", "content": text},
        "finish_reason": "stop",
    });
    serde_json::json!({
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "model": "local",
        "choices": [choice],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    })
}
/// Shorthand: a user-role extraction message with no event timestamp.
fn msg(text: &str) -> ExtractionMessage {
    ExtractionMessage {
        event_time: None,
        role: String::from("user"),
        content: String::from(text),
    }
}
/// The default fan-out ceiling is 4 (`DEFAULT_MAX_CONCURRENCY`).
///
/// Renamed from `batched_default_concurrency_is_eight`: the old name claimed
/// a default of 8 while both assertions pin the actual default of 4.
#[test]
fn batched_default_concurrency_is_four() {
    let inner =
        LocalLLMExtractor::new("http://example.com/v1", "qwen3.6-35b-a3b", None).unwrap();
    let batched = BatchedLocalLLMExtractor::new(inner);
    assert_eq!(batched.max_concurrency(), 4);
    assert_eq!(BatchedLocalLLMExtractor::DEFAULT_MAX_CONCURRENCY, 4);
}
/// A requested limit of zero is clamped up to one.
#[test]
fn batched_with_max_concurrency_clamps_zero_to_one() {
    let single =
        LocalLLMExtractor::new("http://example.com/v1", "qwen3.6-35b-a3b", None).unwrap();
    let clamped = BatchedLocalLLMExtractor::new(single).with_max_concurrency(0);
    assert_eq!(clamped.max_concurrency(), 1);
}
/// An explicit limit above the default is taken as-is.
#[test]
fn batched_with_max_concurrency_overrides_default() {
    let single =
        LocalLLMExtractor::new("http://example.com/v1", "qwen3.6-35b-a3b", None).unwrap();
    let widened = BatchedLocalLLMExtractor::new(single).with_max_concurrency(16);
    assert_eq!(widened.max_concurrency(), 16);
}
#[tokio::test]
async fn batched_delegates_single_extract_to_inner() {
    // Single-episode extract must pass straight through to the inner
    // extractor: exactly one POST, empty result for an empty JSON array.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(200).set_body_json(openai_response_body("[]")))
        .expect(1)
        .mount(&server)
        .await;
    let inner = LocalLLMExtractor::new(server.uri(), "local", None).unwrap();
    let batched = BatchedLocalLLMExtractor::new(inner);
    let out = batched
        .extract(Uuid::new_v4(), Uuid::new_v4(), &[msg("hello")])
        .await
        .expect("ok");
    assert!(out.is_empty());
}
#[tokio::test]
async fn batched_returns_results_in_input_order() {
    // Four mocks, each matched by a distinct tag in the request body and
    // answering with an entity_type that encodes that tag. If the wrapper
    // reorders results, the positional tag/entity pairing below breaks.
    let server = MockServer::start().await;
    for tag in ["alpha", "beta", "gamma", "delta"] {
        let body = format!(
            r#"[{{"entity_type":"tag_{tag}","instance":"x","action":"saw","quantity":1,"confidence":0.9}}]"#,
        );
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .and(wiremock::matchers::body_string_contains(tag))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(openai_response_body(&body)),
            )
            .mount(&server)
            .await;
    }
    let inner = LocalLLMExtractor::new(server.uri(), "local", None).unwrap();
    let batched = BatchedLocalLLMExtractor::new(inner).with_max_concurrency(4);
    // One single-message episode per tag, in a fixed input order.
    let messages = ["alpha", "beta", "gamma", "delta"]
        .iter()
        .map(|t| [msg(t)])
        .collect::<Vec<_>>();
    let ids: Vec<Uuid> = messages.iter().map(|_| Uuid::new_v4()).collect();
    let episodes: Vec<&[ExtractionMessage]> = messages
        .iter()
        .map(<[ExtractionMessage; 1]>::as_slice)
        .collect();
    let out = batched
        .extract_batch(Uuid::new_v4(), &ids, episodes)
        .await
        .expect("ok");
    assert_eq!(out.len(), 4);
    // Results must line up positionally with the input episode order.
    for (i, tag) in ["alpha", "beta", "gamma", "delta"].iter().enumerate() {
        assert_eq!(out[i].len(), 1, "episode {i} should have one observation");
        assert_eq!(
            out[i][0].entity_type,
            format!("tag_{tag}"),
            "episode {i} (input tag={tag}) returned wrong entity_type"
        );
    }
}
#[tokio::test]
async fn batched_fans_out_concurrent_calls() {
    let server = MockServer::start().await;
    // Track how many requests are simultaneously in flight, and the highest
    // value that counter ever reaches.
    let in_flight = Arc::new(AtomicUsize::new(0));
    let peak = Arc::new(AtomicUsize::new(0));
    let delay = Duration::from_millis(150);
    let in_flight_resp = in_flight.clone();
    let peak_resp = peak.clone();
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(move |_req: &wiremock::Request| {
            // Count this request in, then fold the new level into `peak`.
            let cur = in_flight_resp.fetch_add(1, Ordering::SeqCst) + 1;
            peak_resp.fetch_max(cur, Ordering::SeqCst);
            // Count it back out after the same delay the response carries;
            // there is no post-response hook here, so a detached timer task
            // approximates "response finished".
            let in_flight_task = in_flight_resp.clone();
            tokio::spawn(async move {
                tokio::time::sleep(delay).await;
                in_flight_task.fetch_sub(1, Ordering::SeqCst);
            });
            ResponseTemplate::new(200)
                .set_body_json(openai_response_body("[]"))
                .set_delay(delay)
        })
        .mount(&server)
        .await;
    let inner = LocalLLMExtractor::new(server.uri(), "local", None).unwrap();
    let batched = BatchedLocalLLMExtractor::new(inner).with_max_concurrency(4);
    // Eight single-message episodes against a fan-out cap of four.
    let owned: Vec<[ExtractionMessage; 1]> =
        (0..8).map(|i| [msg(&format!("ep{i}"))]).collect();
    let ids: Vec<Uuid> = (0..8).map(|_| Uuid::new_v4()).collect();
    let episodes: Vec<&[ExtractionMessage]> = owned
        .iter()
        .map(<[ExtractionMessage; 1]>::as_slice)
        .collect();
    let out = batched
        .extract_batch(Uuid::new_v4(), &ids, episodes)
        .await
        .expect("ok");
    assert_eq!(out.len(), 8);
    let observed_peak = peak.load(Ordering::SeqCst);
    assert!(
        (2..=4).contains(&observed_peak),
        "observed peak concurrency {observed_peak} should be in [2, 4] \
         with max_concurrency=4 and 8 episodes (lower bound is loose to \
         tolerate scheduler non-determinism; upper bound enforces the \
         semaphore is actually clamping fan-out)"
    );
}
#[tokio::test]
async fn batched_propagates_first_error() {
    // Every request fails with HTTP 500, so the batch must surface a
    // Transport error instead of partial results.
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/chat/completions"))
        .respond_with(ResponseTemplate::new(500).set_body_string("server kaput"))
        .mount(&server)
        .await;
    let batched = BatchedLocalLLMExtractor::new(
        LocalLLMExtractor::new(server.uri(), "local", None).unwrap(),
    )
    .with_max_concurrency(2);
    let owned: Vec<[ExtractionMessage; 1]> =
        (0..3).map(|i| [msg(&format!("e{i}"))]).collect();
    let ids: Vec<Uuid> = (0..3).map(|_| Uuid::new_v4()).collect();
    let episodes: Vec<&[ExtractionMessage]> = owned
        .iter()
        .map(<[ExtractionMessage; 1]>::as_slice)
        .collect();
    let err = batched
        .extract_batch(Uuid::new_v4(), &ids, episodes)
        .await
        .err()
        .expect("expected an error");
    assert!(
        matches!(err, ExtractionError::Transport(_)),
        "expected Transport, got {err:?}"
    );
}
#[tokio::test]
async fn batched_empty_input_returns_empty() {
    // No episodes in means no results out, with no HTTP traffic required.
    let server = MockServer::start().await;
    let batched = BatchedLocalLLMExtractor::new(
        LocalLLMExtractor::new(server.uri(), "local", None).unwrap(),
    );
    let results = batched
        .extract_batch(Uuid::new_v4(), &[], Vec::new())
        .await
        .expect("ok");
    assert!(results.is_empty());
}
#[tokio::test]
async fn batched_rejects_length_mismatch() {
    // Two ids against one episode slice: the wrapper must refuse up front.
    let inner = LocalLLMExtractor::new("http://example.com/v1", "local", None).unwrap();
    let batched = BatchedLocalLLMExtractor::new(inner);
    let message = msg("x");
    let err = batched
        .extract_batch(
            Uuid::new_v4(),
            &[Uuid::new_v4(), Uuid::new_v4()],
            vec![std::slice::from_ref(&message)],
        )
        .await
        .err()
        .expect("expected length-mismatch error");
    match err {
        ExtractionError::Other(text) => {
            assert!(text.contains("length mismatch"), "unexpected msg: {text}");
        }
        unexpected => panic!("expected ExtractionError::Other, got {unexpected:?}"),
    }
}
/// Compile-time check: the batched wrapper coerces to `dyn ObservationExtractor`.
#[allow(dead_code)]
fn batched_is_object_safe() {
    fn accepts(_: &dyn ObservationExtractor) {}
    accepts(&BatchedLocalLLMExtractor::new(
        LocalLLMExtractor::new("http://x/v1", "local", None).unwrap(),
    ));
}
}
}
#[cfg(feature = "observation-extraction")]
pub use batched_localllm::BatchedLocalLLMExtractor;
/// Run `extractor` over one episode's stored messages and persist the
/// resulting observations.
///
/// Pipeline: load the episode's episodic rows from `storage`, convert them
/// to [`ExtractionMessage`]s, call `extractor.extract`, then for each
/// observation compute an embedding via `embed` and save it.
///
/// Every failure (storage read, extraction, embedding, save) is logged at
/// `warn` and swallowed — the episode itself is never rolled back. Returns
/// the number of observations actually persisted (0 on any upstream
/// failure or when the episode has no messages).
pub async fn commit_extraction_for_episode<F, E>(
    storage: &(dyn crate::storage::StorageTrait + Send + Sync),
    extractor: &dyn ObservationExtractor,
    namespace_id: Uuid,
    episode_id: Uuid,
    mut embed: F,
) -> usize
where
    F: FnMut(&str) -> Result<Vec<f32>, E>,
    E: std::fmt::Display,
{
    let raw_messages = match storage.list_episodic_by_episode(namespace_id, episode_id) {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                target: "pensyve::observation",
                error = %e,
                episode_id = %episode_id,
                "failed to load episode messages for extraction"
            );
            return 0;
        }
    };
    // Nothing stored for this episode: skip the extractor entirely.
    if raw_messages.is_empty() {
        return 0;
    }
    let extraction_messages: Vec<ExtractionMessage> = raw_messages
        .iter()
        .map(|m| ExtractionMessage {
            // NOTE(review): role is left empty — the stored episodic rows
            // apparently carry no role field; confirm extractors tolerate
            // an empty role string.
            role: String::new(),
            content: m.content.clone(),
            event_time: m.event_time,
        })
        .collect();
    let observations = match extractor
        .extract(namespace_id, episode_id, &extraction_messages)
        .await
    {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(
                target: "pensyve::observation",
                error = %e,
                episode_id = %episode_id,
                "extractor failed — episode persists without observations"
            );
            return 0;
        }
    };
    let mut persisted = 0usize;
    for mut obs in observations {
        // Embed first: an observation that cannot be embedded is not saved.
        match embed(&obs.content) {
            Ok(v) => obs.embedding = v,
            Err(e) => {
                tracing::warn!(
                    target: "pensyve::observation",
                    error = %e,
                    observation_id = %obs.id,
                    "failed to embed observation content"
                );
                continue;
            }
        }
        if let Err(e) = storage.save_observation(&obs) {
            tracing::warn!(
                target: "pensyve::observation",
                error = %e,
                observation_id = %obs.id,
                "failed to persist observation"
            );
            continue;
        }
        persisted += 1;
    }
    persisted
}
/// Batch variant of [`commit_extraction_for_episode`]: extract and persist
/// observations for several episodes via one `extract_batch` call.
///
/// Episodes whose messages cannot be loaded (logged at `warn`), or that
/// have no messages, are dropped from the batch. If the extractor fails,
/// or returns a result vector whose length does not match the surviving
/// episode list, the whole batch is abandoned and 0 is returned. Otherwise
/// each observation is embedded via `embed` and saved; per-observation
/// failures are logged and skipped. Returns the total number of
/// observations persisted across the batch.
pub async fn commit_extractions_for_episodes<F, E>(
    storage: &(dyn crate::storage::StorageTrait + Send + Sync),
    extractor: &dyn ObservationExtractor,
    namespace_id: Uuid,
    episode_ids: &[Uuid],
    mut embed: F,
) -> usize
where
    F: FnMut(&str) -> Result<Vec<f32>, E>,
    E: std::fmt::Display,
{
    if episode_ids.is_empty() {
        return 0;
    }
    // Collect only episodes that actually have messages; the two vectors
    // stay index-aligned so results can be zipped back to ids.
    let mut surviving_ids: Vec<Uuid> = Vec::with_capacity(episode_ids.len());
    let mut surviving_messages: Vec<Vec<ExtractionMessage>> = Vec::with_capacity(episode_ids.len());
    for eid in episode_ids {
        let raw_messages = match storage.list_episodic_by_episode(namespace_id, *eid) {
            Ok(m) => m,
            Err(e) => {
                tracing::warn!(
                    target: "pensyve::observation",
                    error = %e,
                    episode_id = %eid,
                    "failed to load episode messages for extraction (batch)"
                );
                continue;
            }
        };
        if raw_messages.is_empty() {
            continue;
        }
        let extraction_messages: Vec<ExtractionMessage> = raw_messages
            .iter()
            .map(|m| ExtractionMessage {
                // NOTE(review): role left empty, mirroring the
                // single-episode path — confirm extractors tolerate this.
                role: String::new(),
                content: m.content.clone(),
                event_time: m.event_time,
            })
            .collect();
        surviving_ids.push(*eid);
        surviving_messages.push(extraction_messages);
    }
    if surviving_ids.is_empty() {
        return 0;
    }
    let episode_slices: Vec<&[ExtractionMessage]> =
        surviving_messages.iter().map(Vec::as_slice).collect();
    let batch_results = match extractor
        .extract_batch(namespace_id, &surviving_ids, episode_slices)
        .await
    {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(
                target: "pensyve::observation",
                error = %e,
                batch_size = surviving_ids.len(),
                "batched extractor failed — no observations persisted for this batch"
            );
            return 0;
        }
    };
    // Defensive: a misbehaving extractor impl could break the 1:1 mapping
    // between episodes and result vectors; refuse to guess the alignment.
    if batch_results.len() != surviving_ids.len() {
        tracing::warn!(
            target: "pensyve::observation",
            expected = surviving_ids.len(),
            got = batch_results.len(),
            "batched extractor returned wrong-length result — dropping batch"
        );
        return 0;
    }
    let mut total_persisted = 0usize;
    for (eid, observations) in surviving_ids.iter().zip(batch_results) {
        let mut episode_persisted = 0usize;
        for mut obs in observations {
            // Same rule as the single-episode path: no embedding, no save.
            match embed(&obs.content) {
                Ok(v) => obs.embedding = v,
                Err(e) => {
                    tracing::warn!(
                        target: "pensyve::observation",
                        error = %e,
                        observation_id = %obs.id,
                        episode_id = %eid,
                        "failed to embed observation content (batch)"
                    );
                    continue;
                }
            }
            if let Err(e) = storage.save_observation(&obs) {
                tracing::warn!(
                    target: "pensyve::observation",
                    error = %e,
                    observation_id = %obs.id,
                    episode_id = %eid,
                    "failed to persist observation (batch)"
                );
                continue;
            }
            episode_persisted += 1;
        }
        total_persisted += episode_persisted;
    }
    total_persisted
}
#[cfg(test)]
#[allow(
clippy::unnecessary_wraps,
reason = "test code: `fake_embed` mirrors the embedder closure signature so test fixtures can be swapped in without changing callers"
)]
mod tests {
use super::*;
#[tokio::test]
async fn noop_returns_empty() {
    // NoopExtractor ignores its input entirely and yields nothing.
    let msgs = vec![ExtractionMessage {
        role: "user".into(),
        content: "I played Assassin's Creed Odyssey for 70 hours".into(),
        event_time: None,
    }];
    let observations = NoopExtractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &msgs)
        .await
        .unwrap();
    assert!(observations.is_empty());
}
#[tokio::test]
async fn noop_accepts_empty_messages() {
    // An empty message slice is valid input and still yields nothing.
    let result = NoopExtractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &[])
        .await;
    assert!(result.unwrap().is_empty());
}
/// Compile-time proof that `ObservationExtractor` stays object-safe.
#[allow(dead_code)]
fn trait_is_object_safe() {
    fn accepts(_: &dyn ObservationExtractor) {}
    accepts(&NoopExtractor);
}
/// Test double that returns the same fixed set of observations for any input.
#[derive(Debug, Clone)]
struct MockExtractor {
    // Observations handed back (cloned) verbatim on every `extract` call.
    fixed: Vec<ObservationMemory>,
}
#[async_trait]
impl ObservationExtractor for MockExtractor {
    async fn extract(
        &self,
        _namespace_id: Uuid,
        _episode_id: Uuid,
        _messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        Ok(self.fixed.clone())
    }
}
#[tokio::test]
async fn mock_extractor_passes_through_fixed_output() {
    // The mock must echo exactly the observations it was primed with.
    let namespace = Uuid::new_v4();
    let episode = Uuid::new_v4();
    let expected = vec![ObservationMemory::new(
        namespace,
        episode,
        "game_played",
        "AC Odyssey",
        "played",
        "User played AC Odyssey",
    )];
    let extractor = MockExtractor {
        fixed: expected.clone(),
    };
    let out = extractor.extract(namespace, episode, &[]).await.unwrap();
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].id, expected[0].id);
}
/// Test double that records the `episode_id` of every `extract` call and
/// always returns an empty observation list.
#[derive(Debug, Default)]
struct RecordingExtractor {
    // Episode ids in call order; Arc<Mutex<..>> so the list is observable
    // from the test after the extractor has been used across await points.
    calls: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>,
}
#[async_trait]
impl ObservationExtractor for RecordingExtractor {
    async fn extract(
        &self,
        _namespace_id: Uuid,
        episode_id: Uuid,
        _messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        self.calls.lock().unwrap().push(episode_id);
        Ok(Vec::new())
    }
}
#[tokio::test]
async fn default_extract_batch_falls_through_to_per_episode_extract() {
    // The trait's default `extract_batch` must simply loop over episodes,
    // calling `extract` once per id, in input order.
    let extractor = RecordingExtractor::default();
    let ns = Uuid::new_v4();
    let ids = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
    let msgs: Vec<ExtractionMessage> = (0..3)
        .map(|i| ExtractionMessage {
            role: "user".into(),
            content: format!("ep{i}"),
            event_time: None,
        })
        .collect();
    // Length-1 windows over the owned vec give the borrowed episode slices.
    let episodes: Vec<&[ExtractionMessage]> = msgs.chunks(1).collect();
    let out = extractor
        .extract_batch(ns, &ids, episodes)
        .await
        .expect("default extract_batch ok");
    assert_eq!(out.len(), 3, "one Vec per input episode");
    let recorded = extractor.calls.lock().unwrap().clone();
    assert_eq!(
        recorded.as_slice(),
        ids.as_slice(),
        "extract called per episode in input order"
    );
}
#[tokio::test]
async fn default_extract_batch_rejects_length_mismatch() {
let extractor = RecordingExtractor::default();
let ns = Uuid::new_v4();
let ids = [Uuid::new_v4(), Uuid::new_v4()];
let msg = ExtractionMessage {
role: "user".into(),
content: "x".into(),
event_time: None,
};
let slice = std::slice::from_ref(&msg);
let episodes: Vec<&[ExtractionMessage]> = vec![slice, slice, slice];
let err = extractor
.extract_batch(ns, &ids, episodes)
.await
.expect_err("expected length-mismatch error");
match err {
ExtractionError::Other(msg) => {
assert!(msg.contains("length mismatch"), "unexpected msg: {msg}");
}
other => panic!("expected ExtractionError::Other, got {other:?}"),
}
assert!(
extractor.calls.lock().unwrap().is_empty(),
"no per-episode calls should have happened on rejection"
);
}
/// Test double whose `extract` always fails with a `Transport` error.
#[derive(Debug)]
struct FailingExtractor;
#[async_trait]
impl ObservationExtractor for FailingExtractor {
    async fn extract(
        &self,
        _: Uuid,
        _: Uuid,
        _: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        Err(ExtractionError::Transport("boom".into()))
    }
}
#[tokio::test]
async fn failing_extractor_returns_error() {
    // Sanity check: the failing double really does surface Transport.
    let outcome = FailingExtractor
        .extract(Uuid::new_v4(), Uuid::new_v4(), &[])
        .await;
    assert!(matches!(outcome, Err(ExtractionError::Transport(_))));
}
use crate::storage::StorageTrait;
use crate::storage::sqlite::SqliteBackend;
use crate::types::{EpisodicMemory, Namespace, ObservationMemory};
use tempfile::TempDir;
/// Deterministic stand-in for an embedder: every input maps to `[1.0; 4]`.
fn fake_embed(_text: &str) -> Result<Vec<f32>, std::io::Error> {
    let embedding = vec![1.0_f32; 4];
    Ok(embedding)
}
/// Open a fresh temp-dir SQLite backend seeded with one namespace and one
/// episode holding two user messages. Returns the `TempDir` guard (keep it
/// alive for the DB's lifetime), the backend, the namespace, and the
/// episode id.
fn setup_storage() -> (TempDir, SqliteBackend, Namespace, Uuid) {
    let dir = TempDir::new().unwrap();
    let db = SqliteBackend::open(dir.path()).unwrap();
    let ns = Namespace::new("test-obs-ingest");
    db.save_namespace(&ns).unwrap();
    let episode_id = Uuid::new_v4();
    let src = Uuid::new_v4();
    let about = Uuid::new_v4();
    for content in ["user: I played AC Odyssey", "user: I finished Dune"] {
        let mut mem = EpisodicMemory::new(ns.id, episode_id, src, about, content);
        mem.event_time = Some(Utc::now());
        db.save_episodic(&mem).unwrap();
    }
    (dir, db, ns, episode_id)
}
#[tokio::test]
async fn commit_extraction_noop_persists_nothing() {
    // A no-op extractor yields zero observations, so nothing is saved.
    let (_dir, db, ns, ep) = setup_storage();
    assert_eq!(
        commit_extraction_for_episode(&db, &NoopExtractor, ns.id, ep, fake_embed).await,
        0
    );
}
#[tokio::test]
async fn commit_extraction_persists_mock_observations_with_embeddings() {
    // Two mock observations must both be embedded and stored, carrying the
    // fake embedder's constant vector.
    let (_dir, db, ns, ep) = setup_storage();
    let extractor = MockExtractor {
        fixed: vec![
            ObservationMemory::new(
                ns.id,
                ep,
                "game_played",
                "AC Odyssey",
                "played",
                "played AC Odyssey",
            ),
            ObservationMemory::new(ns.id, ep, "book_read", "Dune", "read", "read Dune"),
        ],
    };
    let persisted = commit_extraction_for_episode(&db, &extractor, ns.id, ep, fake_embed).await;
    assert_eq!(persisted, 2);
    let stored = db.list_observations_by_episode_ids(&[ep], 100).unwrap();
    assert_eq!(stored.len(), 2);
    for obs in &stored {
        assert_eq!(obs.namespace_id, ns.id);
        assert_eq!(obs.episode_id, ep);
        assert_eq!(obs.embedding, vec![1.0_f32; 4]);
    }
    // Both instances made it through, regardless of storage order.
    assert!(stored.iter().any(|o| o.instance == "AC Odyssey"));
    assert!(stored.iter().any(|o| o.instance == "Dune"));
}
#[tokio::test]
async fn commit_extraction_swallows_extractor_failure() {
    // Extraction errors are logged, not propagated: nothing is persisted
    // and the underlying episodic rows remain intact.
    let (_dir, db, ns, ep) = setup_storage();
    let persisted =
        commit_extraction_for_episode(&db, &FailingExtractor, ns.id, ep, fake_embed).await;
    assert_eq!(persisted, 0);
    assert_eq!(db.list_episodic_by_episode(ns.id, ep).unwrap().len(), 2);
}
#[tokio::test]
async fn commit_extraction_swallows_embedding_failure() {
    // An embedder error drops the observation but does not abort the run.
    let (_dir, db, ns, ep) = setup_storage();
    let extractor = MockExtractor {
        fixed: vec![ObservationMemory::new(ns.id, ep, "x", "y", "z", "z y")],
    };
    let broken_embed = |_: &str| -> Result<Vec<f32>, std::io::Error> {
        Err(std::io::Error::other("embedder down"))
    };
    assert_eq!(
        commit_extraction_for_episode(&db, &extractor, ns.id, ep, broken_embed).await,
        0
    );
    assert!(
        db.list_observations_by_episode_ids(&[ep], 100)
            .unwrap()
            .is_empty()
    );
}
#[tokio::test]
async fn commit_extraction_skips_when_episode_has_no_messages() {
    // With no stored messages the extractor is never consulted, so even a
    // mock primed with output persists nothing.
    let dir = TempDir::new().unwrap();
    let db = SqliteBackend::open(dir.path()).unwrap();
    let ns = Namespace::new("empty");
    db.save_namespace(&ns).unwrap();
    let ep = Uuid::new_v4();
    let extractor = MockExtractor {
        fixed: vec![ObservationMemory::new(
            ns.id, ep, "should", "not", "persist", "",
        )],
    };
    assert_eq!(
        commit_extraction_for_episode(&db, &extractor, ns.id, ep, fake_embed).await,
        0
    );
}
/// Like `setup_storage` but seeds two distinct episodes (two user messages
/// each) in one namespace, so batch-level behavior can be exercised.
/// Returns the `TempDir` guard, backend, namespace, and both episode ids.
fn setup_two_episodes() -> (TempDir, SqliteBackend, Namespace, Uuid, Uuid) {
    let dir = TempDir::new().unwrap();
    let db = SqliteBackend::open(dir.path()).unwrap();
    let ns = Namespace::new("test-batch-ingest");
    db.save_namespace(&ns).unwrap();
    let ep_a = Uuid::new_v4();
    let ep_b = Uuid::new_v4();
    let src = Uuid::new_v4();
    let about = Uuid::new_v4();
    for content in ["user: I played AC Odyssey", "user: I finished Dune"] {
        let mut mem = EpisodicMemory::new(ns.id, ep_a, src, about, content);
        mem.event_time = Some(Utc::now());
        db.save_episodic(&mem).unwrap();
    }
    for content in ["user: I baked sourdough", "user: I read Foundation"] {
        let mut mem = EpisodicMemory::new(ns.id, ep_b, src, about, content);
        mem.event_time = Some(Utc::now());
        db.save_episodic(&mem).unwrap();
    }
    (dir, db, ns, ep_a, ep_b)
}
/// Test double that answers with a canned observation list per episode id
/// (empty for unknown ids), letting batch tests vary output per episode.
#[derive(Debug, Clone)]
struct PerEpisodeMockExtractor {
    // episode id → observations to return for that episode.
    by_episode: std::collections::HashMap<Uuid, Vec<ObservationMemory>>,
}
#[async_trait]
impl ObservationExtractor for PerEpisodeMockExtractor {
    async fn extract(
        &self,
        _namespace_id: Uuid,
        episode_id: Uuid,
        _messages: &[ExtractionMessage],
    ) -> ExtractionResult<Vec<ObservationMemory>> {
        Ok(self
            .by_episode
            .get(&episode_id)
            .cloned()
            .unwrap_or_default())
    }
}
#[tokio::test]
async fn commit_extractions_batch_persists_per_episode_observations() {
    // Each episode gets its own canned observation; both must land in
    // storage under the right episode id.
    let (_dir, db, ns, ep_a, ep_b) = setup_two_episodes();
    let by_episode = std::collections::HashMap::from([
        (
            ep_a,
            vec![ObservationMemory::new(
                ns.id,
                ep_a,
                "game_played",
                "AC Odyssey",
                "played",
                "played AC Odyssey",
            )],
        ),
        (
            ep_b,
            vec![ObservationMemory::new(
                ns.id,
                ep_b,
                "food_made",
                "sourdough",
                "baked",
                "baked sourdough",
            )],
        ),
    ]);
    let extractor = PerEpisodeMockExtractor { by_episode };
    let persisted =
        commit_extractions_for_episodes(&db, &extractor, ns.id, &[ep_a, ep_b], fake_embed).await;
    assert_eq!(persisted, 2);
    let stored_a = db.list_observations_by_episode_ids(&[ep_a], 100).unwrap();
    assert_eq!(stored_a.len(), 1);
    assert_eq!(stored_a[0].instance, "AC Odyssey");
    let stored_b = db.list_observations_by_episode_ids(&[ep_b], 100).unwrap();
    assert_eq!(stored_b.len(), 1);
    assert_eq!(stored_b[0].instance, "sourdough");
}
#[tokio::test]
async fn commit_extractions_batch_empty_input_is_noop() {
    // An empty id slice short-circuits before touching the extractor.
    let (_dir, db, ns, _ep_a, _ep_b) = setup_two_episodes();
    assert_eq!(
        commit_extractions_for_episodes(&db, &NoopExtractor, ns.id, &[], fake_embed).await,
        0
    );
}
#[tokio::test]
async fn commit_extractions_batch_swallows_extractor_failure() {
    // A failing batch extractor means zero persisted observations for
    // every episode in the batch — and no panic.
    let (_dir, db, ns, ep_a, ep_b) = setup_two_episodes();
    let persisted = commit_extractions_for_episodes(
        &db,
        &FailingExtractor,
        ns.id,
        &[ep_a, ep_b],
        fake_embed,
    )
    .await;
    assert_eq!(persisted, 0);
    for ep in [ep_a, ep_b] {
        let stored = db.list_observations_by_episode_ids(&[ep], 100).unwrap();
        assert!(stored.is_empty());
    }
}
#[tokio::test]
async fn commit_extractions_batch_drops_episodes_with_no_messages() {
    // `phantom_ep` has no stored messages, so it must be filtered out of
    // the batch before extraction; only `ep_a`'s observation is persisted.
    let (_dir, db, ns, ep_a, _ep_b) = setup_two_episodes();
    // Fixed: the two statements below were crammed onto one source line.
    let phantom_ep = Uuid::new_v4();
    let mut by_episode = std::collections::HashMap::new();
    by_episode.insert(
        ep_a,
        vec![ObservationMemory::new(ns.id, ep_a, "x", "y", "z", "z y")],
    );
    let extractor = PerEpisodeMockExtractor { by_episode };
    let persisted = commit_extractions_for_episodes(
        &db,
        &extractor,
        ns.id,
        &[ep_a, phantom_ep],
        fake_embed,
    )
    .await;
    assert_eq!(persisted, 1);
    let stored = db.list_observations_by_episode_ids(&[ep_a], 100).unwrap();
    assert_eq!(stored.len(), 1);
}
}