Skip to main content

bamboo_server/server_tools/
memory.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::json;
7use tokio::sync::RwLock;
8
9use bamboo_agent_core::storage::Storage;
10use bamboo_agent_core::tools::{Tool, ToolError, ToolExecutionContext, ToolResult};
11use bamboo_agent_core::Session;
12use bamboo_memory::memory_store::{
13    DurableMemoryStatus, DurableMemoryType, MemoryQueryOptions, MemoryScope, MemoryStore,
14    MAX_MAX_CHARS, MAX_QUERY_LIMIT,
15};
16use bamboo_tools::tools::session_memory::{
17    execute_session_memory_action, SessionMemoryAction, MEMORY_SESSION_ACTION_NAMES,
18};
19
20type FilterTypeSet = (
21    Option<HashSet<DurableMemoryType>>,
22    Option<HashSet<DurableMemoryStatus>>,
23);
24
25#[derive(Clone)]
26pub struct MemoryTool {
27    sessions: Arc<RwLock<std::collections::HashMap<String, Session>>>,
28    storage: Arc<dyn Storage>,
29    memory_store: MemoryStore,
30}
31
32impl MemoryTool {
33    pub fn new(
34        sessions: Arc<RwLock<std::collections::HashMap<String, Session>>>,
35        storage: Arc<dyn Storage>,
36        data_dir: impl Into<std::path::PathBuf>,
37    ) -> Self {
38        Self {
39            sessions,
40            storage,
41            memory_store: MemoryStore::new(data_dir),
42        }
43    }
44
45    async fn session_for_context(&self, session_id: Option<&str>) -> Option<Session> {
46        let session_id = session_id?;
47        let in_memory = {
48            let sessions = self.sessions.read().await;
49            sessions.get(session_id).cloned()
50        };
51        match in_memory {
52            Some(session) => Some(session),
53            None => self.storage.load_session(session_id).await.ok().flatten(),
54        }
55    }
56
57    async fn resolve_project_key(
58        &self,
59        explicit: Option<&str>,
60        session_id: Option<&str>,
61    ) -> Option<String> {
62        if let Some(explicit) = explicit
63            .map(str::trim)
64            .filter(|value| !value.is_empty())
65            .map(ToString::to_string)
66        {
67            return Some(explicit);
68        }
69
70        if let Some(project_key) = self.memory_store.project_key_for_session(session_id) {
71            return Some(project_key);
72        }
73
74        self.session_for_context(session_id)
75            .await
76            .and_then(|session| session.metadata.get("workspace_path").cloned())
77            .map(std::path::PathBuf::from)
78            .map(|path| bamboo_memory::memory_store::project_key_from_path(&path))
79    }
80
81    fn parse_scope(scope: Option<&str>) -> Result<MemoryScope, ToolError> {
82        match scope
83            .unwrap_or("session")
84            .trim()
85            .to_ascii_lowercase()
86            .as_str()
87        {
88            "session" => Ok(MemoryScope::Session),
89            "project" => Ok(MemoryScope::Project),
90            "global" => Ok(MemoryScope::Global),
91            other => Err(ToolError::InvalidArguments(format!(
92                "invalid scope '{other}'; expected one of: session, project, global"
93            ))),
94        }
95    }
96
97    fn parse_type(value: &str) -> Result<DurableMemoryType, ToolError> {
98        match value.trim().to_ascii_lowercase().as_str() {
99            "user" => Ok(DurableMemoryType::User),
100            "feedback" => Ok(DurableMemoryType::Feedback),
101            "project" => Ok(DurableMemoryType::Project),
102            "reference" => Ok(DurableMemoryType::Reference),
103            other => Err(ToolError::InvalidArguments(format!(
104                "invalid type '{other}'; expected one of: user, feedback, project, reference"
105            ))),
106        }
107    }
108
109    fn parse_status(value: &str) -> Result<DurableMemoryStatus, ToolError> {
110        match value.trim().to_ascii_lowercase().as_str() {
111            "active" => Ok(DurableMemoryStatus::Active),
112            "stale" => Ok(DurableMemoryStatus::Stale),
113            "superseded" => Ok(DurableMemoryStatus::Superseded),
114            "contradicted" => Ok(DurableMemoryStatus::Contradicted),
115            "archived" => Ok(DurableMemoryStatus::Archived),
116            other => Err(ToolError::InvalidArguments(format!(
117                "invalid status '{other}'; expected one of: active, stale, superseded, contradicted, archived"
118            ))),
119        }
120    }
121
122    fn parse_query_filters(filters: Option<&QueryFilters>) -> Result<FilterTypeSet, ToolError> {
123        let filter_types = filters
124            .map(|value| {
125                value
126                    .r#type
127                    .iter()
128                    .map(|item| Self::parse_type(item))
129                    .collect::<Result<HashSet<_>, _>>()
130            })
131            .transpose()?;
132        let filter_statuses = filters
133            .map(|value| {
134                value
135                    .status
136                    .iter()
137                    .map(|item| Self::parse_status(item))
138                    .collect::<Result<HashSet<_>, _>>()
139            })
140            .transpose()?;
141        Ok((filter_types, filter_statuses))
142    }
143
144    fn parse_merge_mode(value: Option<&str>) -> Result<Option<String>, ToolError> {
145        let Some(mode) = value.map(str::trim).filter(|value| !value.is_empty()) else {
146            return Ok(None);
147        };
148        let normalized = mode.to_ascii_lowercase();
149        match normalized.as_str() {
150            "semantic_merge" | "merge" | "contradict" => Ok(Some(normalized)),
151            other => Err(ToolError::InvalidArguments(format!(
152                "invalid merge mode '{other}'; expected one of: merge, semantic_merge, contradict"
153            ))),
154        }
155    }
156}
157
158#[derive(Debug, Deserialize)]
159#[serde(tag = "action", rename_all = "snake_case")]
160enum MemoryArgs {
161    SessionRead {
162        #[serde(default)]
163        topic: Option<String>,
164        #[serde(default)]
165        options: Option<MemoryActionOptions>,
166    },
167    SessionAppend {
168        #[serde(default)]
169        topic: Option<String>,
170        content: String,
171    },
172    SessionReplace {
173        #[serde(default)]
174        topic: Option<String>,
175        content: String,
176    },
177    SessionClear {
178        #[serde(default)]
179        topic: Option<String>,
180    },
181    SessionListTopics,
182    Query {
183        scope: String,
184        #[serde(default)]
185        query: Option<String>,
186        #[serde(default)]
187        filters: Option<QueryFilters>,
188        #[serde(default)]
189        project_key: Option<String>,
190        #[serde(default)]
191        options: Option<MemoryActionOptions>,
192    },
193    Get {
194        id: String,
195        #[serde(default)]
196        project_key: Option<String>,
197        #[serde(default)]
198        options: Option<MemoryActionOptions>,
199    },
200    Write {
201        scope: String,
202        #[serde(rename = "type")]
203        r#type: String,
204        title: String,
205        content: String,
206        #[serde(default)]
207        tags: Vec<String>,
208        #[serde(default)]
209        project_key: Option<String>,
210        #[serde(default)]
211        options: Option<WriteOptions>,
212    },
213    Merge {
214        id: String,
215        content: String,
216        #[serde(default)]
217        tags: Vec<String>,
218        #[serde(default)]
219        project_key: Option<String>,
220        #[serde(default)]
221        source_memory_ids: Vec<String>,
222        #[serde(default)]
223        mode: Option<String>,
224        #[serde(default)]
225        reason: Option<String>,
226    },
227    Purge {
228        #[serde(default)]
229        id: Option<String>,
230        #[serde(default)]
231        scope: Option<String>,
232        #[serde(default)]
233        reason: Option<String>,
234        #[serde(default)]
235        project_key: Option<String>,
236        #[serde(default)]
237        filters: Option<QueryFilters>,
238        #[serde(default)]
239        mode: Option<String>,
240    },
241    Inspect {
242        scope: String,
243        #[serde(default)]
244        project_key: Option<String>,
245    },
246    Rebuild {
247        scope: String,
248        #[serde(default)]
249        project_key: Option<String>,
250    },
251}
252
253#[derive(Debug, Deserialize, Default)]
254struct MemoryActionOptions {
255    #[serde(default)]
256    limit: Option<usize>,
257    #[serde(default)]
258    max_chars: Option<usize>,
259    #[serde(default)]
260    cursor: Option<String>,
261    #[serde(default)]
262    include_related: Option<bool>,
263}
264
265#[derive(Debug, Deserialize, Default)]
266struct QueryFilters {
267    #[serde(default)]
268    r#type: Vec<String>,
269    #[serde(default)]
270    status: Vec<String>,
271}
272
273#[derive(Debug, Deserialize, Default)]
274struct WriteOptions {
275    #[serde(default)]
276    allow_merge_if_similar: Option<bool>,
277}
278
279#[async_trait]
280impl Tool for MemoryTool {
281    fn name(&self) -> &str {
282        "memory"
283    }
284
285    fn description(&self) -> &str {
286        "Unified memory management tool for Bamboo. Use session_* actions for session continuity notes, and query/get/write/purge/inspect/rebuild for durable project/global memory backed by canonical topic files and derived indexes."
287    }
288
289    fn parameters_schema(&self) -> serde_json::Value {
290        json!({
291            "type": "object",
292            "properties": {
293                "action": {
294                    "type": "string",
295                    "enum": [
296                        "session_read",
297                        "session_append",
298                        "session_replace",
299                        "session_clear",
300                        "session_list_topics",
301                        "query",
302                        "get",
303                        "write",
304                        "merge",
305                        "purge",
306                        "inspect",
307                        "rebuild"
308                    ]
309                },
310                "scope": {"type": "string", "enum": ["session", "project", "global"]},
311                "project_key": {"type": "string"},
312                "topic": {"type": "string"},
313                "id": {"type": "string"},
314                "query": {"type": "string"},
315                "type": {"type": "string", "enum": ["user", "feedback", "project", "reference"]},
316                "title": {"type": "string"},
317                "content": {"type": "string"},
318                "tags": {"type": "array", "items": {"type": "string"}},
319                "filters": {"type": "object"},
320                "options": {"type": "object"},
321                "reason": {"type": "string"}
322            },
323            "required": ["action"]
324        })
325    }
326
327    fn call_mutability(&self, args: &serde_json::Value) -> bamboo_tools::ToolMutability {
328        let action = args
329            .get("action")
330            .and_then(|value| value.as_str())
331            .unwrap_or("")
332            .trim()
333            .to_ascii_lowercase();
334        match action.as_str() {
335            "session_read" | "session_list_topics" | "query" | "get" | "inspect" => {
336                bamboo_tools::ToolMutability::ReadOnly
337            }
338            _ => bamboo_tools::ToolMutability::Mutating,
339        }
340    }
341
342    fn call_concurrency_safe(&self, args: &serde_json::Value) -> bool {
343        matches!(
344            self.call_mutability(args),
345            bamboo_tools::ToolMutability::ReadOnly
346        )
347    }
348
349    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
350        self.execute_with_context(args, ToolExecutionContext::none("tool_call"))
351            .await
352    }
353
354    async fn execute_with_context(
355        &self,
356        args: serde_json::Value,
357        ctx: ToolExecutionContext<'_>,
358    ) -> Result<ToolResult, ToolError> {
359        let session_id = ctx.session_id.ok_or_else(|| {
360            ToolError::Execution("memory requires a session_id in tool context".to_string())
361        })?;
362
363        let parsed: MemoryArgs = serde_json::from_value(args).map_err(|error| {
364            ToolError::InvalidArguments(format!("Invalid memory args: {error}"))
365        })?;
366
367        match parsed {
368            MemoryArgs::SessionRead { topic, options } => {
369                let max_chars = options.and_then(|value| value.max_chars);
370                execute_session_memory_action(
371                    &self.memory_store,
372                    session_id,
373                    SessionMemoryAction::Read,
374                    topic.as_deref(),
375                    None,
376                    max_chars,
377                    MEMORY_SESSION_ACTION_NAMES,
378                )
379                .await
380            }
381            MemoryArgs::SessionAppend { topic, content } => {
382                execute_session_memory_action(
383                    &self.memory_store,
384                    session_id,
385                    SessionMemoryAction::Append,
386                    topic.as_deref(),
387                    Some(content.as_str()),
388                    None,
389                    MEMORY_SESSION_ACTION_NAMES,
390                )
391                .await
392            }
393            MemoryArgs::SessionReplace { topic, content } => {
394                execute_session_memory_action(
395                    &self.memory_store,
396                    session_id,
397                    SessionMemoryAction::Replace,
398                    topic.as_deref(),
399                    Some(content.as_str()),
400                    None,
401                    MEMORY_SESSION_ACTION_NAMES,
402                )
403                .await
404            }
405            MemoryArgs::SessionClear { topic } => {
406                execute_session_memory_action(
407                    &self.memory_store,
408                    session_id,
409                    SessionMemoryAction::Clear,
410                    topic.as_deref(),
411                    None,
412                    None,
413                    MEMORY_SESSION_ACTION_NAMES,
414                )
415                .await
416            }
417            MemoryArgs::SessionListTopics => {
418                execute_session_memory_action(
419                    &self.memory_store,
420                    session_id,
421                    SessionMemoryAction::ListTopics,
422                    None,
423                    None,
424                    None,
425                    MEMORY_SESSION_ACTION_NAMES,
426                )
427                .await
428            }
429            MemoryArgs::Query {
430                scope,
431                query,
432                filters,
433                project_key,
434                options,
435            } => {
436                let scope = Self::parse_scope(Some(&scope))?;
437                if scope == MemoryScope::Session {
438                    return Err(ToolError::InvalidArguments(
439                        "query supports durable scopes only; use session_read/session_list_topics for session scope"
440                            .to_string(),
441                    ));
442                }
443                let project_key = self
444                    .resolve_project_key(project_key.as_deref(), Some(session_id))
445                    .await;
446                let options = MemoryQueryOptions {
447                    limit: options
448                        .as_ref()
449                        .and_then(|value| value.limit)
450                        .map(|value| value.min(MAX_QUERY_LIMIT)),
451                    max_chars: options
452                        .as_ref()
453                        .and_then(|value| value.max_chars)
454                        .map(|value| value.min(MAX_MAX_CHARS)),
455                    cursor: options.as_ref().and_then(|value| value.cursor.clone()),
456                    include_related: options
457                        .as_ref()
458                        .and_then(|value| value.include_related)
459                        .unwrap_or(false),
460                };
461                let (filter_types, filter_statuses) = Self::parse_query_filters(filters.as_ref())?;
462                let result = self
463                    .memory_store
464                    .query_scope(
465                        scope,
466                        project_key.as_deref(),
467                        query.as_deref(),
468                        filter_types.as_ref(),
469                        filter_statuses.as_ref(),
470                        &options,
471                    )
472                    .await
473                    .map_err(|error| {
474                        ToolError::Execution(format!("Failed to query memory: {error}"))
475                    })?;
476                Ok(ToolResult {
477                    success: true,
478                    result: json!({
479                        "action": "query",
480                        "success": true,
481                        "data": result,
482                        "summary": bamboo_memory::memory_store::summary_json(result.returned_count, result.matched_count),
483                        "warnings": [],
484                    }).to_string(),
485                    display_preference: Some("json".to_string()),
486                })
487            }
488            MemoryArgs::Get {
489                id,
490                project_key,
491                options,
492            } => {
493                let project_key = self
494                    .resolve_project_key(project_key.as_deref(), Some(session_id))
495                    .await;
496                let max_chars = options
497                    .and_then(|value| value.max_chars)
498                    .unwrap_or(MAX_MAX_CHARS)
499                    .min(MAX_MAX_CHARS);
500                let Some(mut doc) = self
501                    .memory_store
502                    .get_memory(id.trim(), project_key.as_deref())
503                    .await
504                    .map_err(|error| {
505                        ToolError::Execution(format!("Failed to get memory: {error}"))
506                    })?
507                else {
508                    return Err(ToolError::Execution(format!(
509                        "memory not found: {}",
510                        id.trim()
511                    )));
512                };
513                let (body, truncated) =
514                    bamboo_memory::memory_store::truncate_chars(&doc.body, max_chars);
515                doc.body = body;
516                Ok(ToolResult {
517                    success: true,
518                    result: json!({
519                        "action": "get",
520                        "id": doc.frontmatter.id,
521                        "memory": {
522                            "frontmatter": doc.frontmatter,
523                            "body": doc.body,
524                            "path": doc.path,
525                            "body_truncated": truncated,
526                        }
527                    })
528                    .to_string(),
529                    display_preference: Some("json".to_string()),
530                })
531            }
532            MemoryArgs::Write {
533                scope,
534                r#type,
535                title,
536                content,
537                tags,
538                project_key,
539                options,
540            } => {
541                let scope = Self::parse_scope(Some(&scope))?;
542                if scope == MemoryScope::Session {
543                    return Err(ToolError::InvalidArguments(
544                        "write supports durable scopes only; use session_replace/session_append for session scope"
545                            .to_string(),
546                    ));
547                }
548                let project_key = self
549                    .resolve_project_key(project_key.as_deref(), Some(session_id))
550                    .await;
551                let doc = self
552                    .memory_store
553                    .write_memory(
554                        scope,
555                        project_key.as_deref(),
556                        Self::parse_type(&r#type)?,
557                        &title,
558                        &content,
559                        &tags,
560                        Some(session_id),
561                        "main-model",
562                        options
563                            .and_then(|value| value.allow_merge_if_similar)
564                            .unwrap_or(true),
565                    )
566                    .await
567                    .map_err(|error| {
568                        ToolError::Execution(format!("Failed to write memory: {error}"))
569                    })?;
570                Ok(ToolResult {
571                    success: true,
572                    result: json!({
573                        "action": "write",
574                        "memory": {
575                            "id": doc.frontmatter.id,
576                            "title": doc.frontmatter.title,
577                            "type": doc.frontmatter.r#type,
578                            "scope": doc.frontmatter.scope,
579                            "status": doc.frontmatter.status,
580                            "project_key": doc.frontmatter.project_key,
581                            "path": doc.path,
582                        }
583                    })
584                    .to_string(),
585                    display_preference: Some("json".to_string()),
586                })
587            }
588            MemoryArgs::Merge {
589                id,
590                content,
591                tags,
592                project_key,
593                source_memory_ids,
594                mode,
595                reason,
596            } => {
597                let project_key = self
598                    .resolve_project_key(project_key.as_deref(), Some(session_id))
599                    .await;
600                let mode = Self::parse_merge_mode(mode.as_deref())?;
601                if matches!(mode.as_deref(), Some("contradict")) {
602                    let Some(result) = self
603                        .memory_store
604                        .mark_memory_contradicted(
605                            id.trim(),
606                            project_key.as_deref(),
607                            &source_memory_ids,
608                            reason.as_deref().or(Some(content.trim())),
609                            Some(session_id),
610                            "main-model",
611                        )
612                        .await
613                        .map_err(|error| {
614                            ToolError::Execution(format!("Failed to contradict memory: {error}"))
615                        })?
616                    else {
617                        return Err(ToolError::Execution(format!(
618                            "memory not found: {}",
619                            id.trim()
620                        )));
621                    };
622                    Ok(ToolResult {
623                        success: true,
624                        result: json!({
625                            "action": "merge",
626                            "mode": "contradict",
627                            "data": result,
628                        })
629                        .to_string(),
630                        display_preference: Some("json".to_string()),
631                    })
632                } else {
633                    let Some(result) = self
634                        .memory_store
635                        .merge_memory(
636                            id.trim(),
637                            project_key.as_deref(),
638                            &content,
639                            &tags,
640                            Some(session_id),
641                            "main-model",
642                            &source_memory_ids,
643                        )
644                        .await
645                        .map_err(|error| {
646                            ToolError::Execution(format!("Failed to merge memory: {error}"))
647                        })?
648                    else {
649                        return Err(ToolError::Execution(format!(
650                            "memory not found: {}",
651                            id.trim()
652                        )));
653                    };
654                    Ok(ToolResult {
655                        success: true,
656                        result: json!({
657                            "action": "merge",
658                            "mode": mode.unwrap_or_else(|| "merge".to_string()),
659                            "data": result,
660                        })
661                        .to_string(),
662                        display_preference: Some("json".to_string()),
663                    })
664                }
665            }
666            MemoryArgs::Purge {
667                id,
668                scope,
669                reason,
670                project_key,
671                filters,
672                mode,
673            } => {
674                let mode = match mode
675                    .as_deref()
676                    .map(str::trim)
677                    .filter(|value| !value.is_empty())
678                {
679                    Some(value) => Self::parse_status(value)?,
680                    None => DurableMemoryStatus::Archived,
681                };
682                let project_key = self
683                    .resolve_project_key(project_key.as_deref(), Some(session_id))
684                    .await;
685
686                if let Some(id) = id
687                    .as_deref()
688                    .map(str::trim)
689                    .filter(|value| !value.is_empty())
690                {
691                    let Some(doc) = self
692                        .memory_store
693                        .archive_memory(id, project_key.as_deref(), mode, reason.as_deref())
694                        .await
695                        .map_err(|error| {
696                            ToolError::Execution(format!("Failed to purge memory: {error}"))
697                        })?
698                    else {
699                        return Err(ToolError::Execution(format!("memory not found: {}", id)));
700                    };
701                    Ok(ToolResult {
702                        success: true,
703                        result: json!({
704                            "action": "purge",
705                            "id": doc.frontmatter.id,
706                            "status": doc.frontmatter.status,
707                        })
708                        .to_string(),
709                        display_preference: Some("json".to_string()),
710                    })
711                } else {
712                    let scope = Self::parse_scope(scope.as_deref())?;
713                    if scope == MemoryScope::Session {
714                        return Err(ToolError::InvalidArguments(
715                            "purge supports durable scopes only in v1".to_string(),
716                        ));
717                    }
718                    let (filter_types, filter_statuses) =
719                        Self::parse_query_filters(filters.as_ref())?;
720                    let result = self
721                        .memory_store
722                        .purge_memories(
723                            scope,
724                            project_key.as_deref(),
725                            filter_types.as_ref(),
726                            filter_statuses.as_ref(),
727                            mode,
728                            reason.as_deref(),
729                        )
730                        .await
731                        .map_err(|error| {
732                            ToolError::Execution(format!("Failed to purge memory: {error}"))
733                        })?;
734                    Ok(ToolResult {
735                        success: true,
736                        result: json!({
737                            "action": "purge",
738                            "data": result,
739                        })
740                        .to_string(),
741                        display_preference: Some("json".to_string()),
742                    })
743                }
744            }
745            MemoryArgs::Inspect { scope, project_key } => {
746                let scope = Self::parse_scope(Some(&scope))?;
747                if scope == MemoryScope::Session {
748                    return Err(ToolError::InvalidArguments(
749                        "inspect supports durable scopes only in v1".to_string(),
750                    ));
751                }
752                let project_key = self
753                    .resolve_project_key(project_key.as_deref(), Some(session_id))
754                    .await;
755                let result = self
756                    .memory_store
757                    .inspect_scope(scope, project_key.as_deref())
758                    .await
759                    .map_err(|error| {
760                        ToolError::Execution(format!("Failed to inspect memory: {error}"))
761                    })?;
762                Ok(ToolResult {
763                    success: true,
764                    result: json!({
765                        "action": "inspect",
766                        "data": result,
767                    })
768                    .to_string(),
769                    display_preference: Some("json".to_string()),
770                })
771            }
772            MemoryArgs::Rebuild { scope, project_key } => {
773                let scope = Self::parse_scope(Some(&scope))?;
774                if scope == MemoryScope::Session {
775                    return Err(ToolError::InvalidArguments(
776                        "rebuild supports durable scopes only in v1".to_string(),
777                    ));
778                }
779                let project_key = self
780                    .resolve_project_key(project_key.as_deref(), Some(session_id))
781                    .await;
782                self.memory_store
783                    .rebuild_scope(scope, project_key.as_deref())
784                    .await
785                    .map_err(|error| {
786                        ToolError::Execution(format!("Failed to rebuild memory artifacts: {error}"))
787                    })?;
788                let inspect = self
789                    .memory_store
790                    .inspect_scope(scope, project_key.as_deref())
791                    .await
792                    .map_err(|error| {
793                        ToolError::Execution(format!("Failed to inspect rebuilt memory: {error}"))
794                    })?;
795                Ok(ToolResult {
796                    success: true,
797                    result: json!({
798                        "action": "rebuild",
799                        "scope": scope,
800                        "project_key": project_key,
801                        "data": inspect,
802                    })
803                    .to_string(),
804                    display_preference: Some("json".to_string()),
805                })
806            }
807        }
808    }
809}
810
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashMap;

    use tokio::sync::RwLock;

    /// In-memory `Storage` stub: sessions live in a `HashMap` behind an async `RwLock`.
    #[derive(Default)]
    struct TestStorage {
        sessions: RwLock<HashMap<String, Session>>,
    }

    #[async_trait]
    impl Storage for TestStorage {
        async fn save_session(&self, session: &Session) -> std::io::Result<()> {
            self.sessions
                .write()
                .await
                .insert(session.id.clone(), session.clone());
            Ok(())
        }

        async fn load_session(&self, session_id: &str) -> std::io::Result<Option<Session>> {
            Ok(self.sessions.read().await.get(session_id).cloned())
        }

        async fn delete_session(&self, session_id: &str) -> std::io::Result<bool> {
            Ok(self.sessions.write().await.remove(session_id).is_some())
        }
    }

    /// Builds an execution context bound to `session_id`, with no event channel and no
    /// available tool schemas.
    // The lifetime is elided: the returned context borrows only from `session_id`
    // (`tool_call_id` is a `'static` literal), so an explicit `'a` is redundant.
    fn test_context(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "tool-call-1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Creates a `MemoryTool` with an empty in-memory session map, a `TestStorage` backend,
    /// and its durable memory store rooted at `data_dir`.
    fn build_memory_tool(data_dir: &std::path::Path) -> MemoryTool {
        let sessions = Arc::new(RwLock::new(HashMap::new()));
        let storage: Arc<dyn Storage> = Arc::new(TestStorage::default());
        MemoryTool::new(sessions, storage, data_dir)
    }

    /// Parses the JSON text carried in `result.result`.
    ///
    /// # Panics
    ///
    /// Panics if the payload is not valid JSON; the panic message includes the raw payload
    /// so a failing test shows what the tool actually returned.
    fn parse_result(result: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&result.result).unwrap_or_else(|error| {
            panic!("tool result is not valid JSON ({error}): {}", result.result)
        })
    }

    /// `session_read` honours `options.max_chars`: the returned `content` is cut to the limit,
    /// while `length_chars` still reports the full note length and `body_truncated` is set.
    #[tokio::test]
    async fn memory_session_actions_share_read_shape_and_limits() {
        let dir = tempfile::tempdir().expect("tempdir");
        let tool = build_memory_tool(dir.path());

        tool.execute_with_context(
            json!({"action":"session_replace","topic":"default","content":"x".repeat(32)}),
            test_context("session-1"),
        )
        .await
        .expect("session replace should succeed");

        let read = tool
            .execute_with_context(
                json!({"action":"session_read","topic":"default","options":{"max_chars":8}}),
                test_context("session-1"),
            )
            .await
            .expect("session read should succeed");
        let value = parse_result(&read);
        assert_eq!(value["action"], "session_read");
        assert_eq!(value["length_chars"], 32);
        assert_eq!(value["body_truncated"], true);
        let content = value["content"]
            .as_str()
            .expect("session_read result should carry a string `content` field");
        assert_eq!(content.chars().count(), 8);
    }

    /// With a note already one character below `MAX_SESSION_NOTE_CHARS`, appending even a
    /// single character is rejected, and the error points the caller at `session_read` and
    /// `session_replace`.
    #[tokio::test]
    async fn memory_session_append_enforces_shared_limit() {
        let dir = tempfile::tempdir().expect("tempdir");
        let tool = build_memory_tool(dir.path());
        let near_limit =
            "x".repeat(bamboo_tools::tools::session_memory::MAX_SESSION_NOTE_CHARS - 1);

        tool.execute_with_context(
            json!({"action":"session_replace","topic":"limit","content":near_limit}),
            test_context("session-2"),
        )
        .await
        .expect("session replace near limit should succeed");

        let err = tool
            .execute_with_context(
                json!({"action":"session_append","topic":"limit","content":"y"}),
                test_context("session-2"),
            )
            .await
            .expect_err("session append should fail");
        let message = err.to_string();
        for expected in [
            "session note would exceed the limit",
            "action=session_read",
            "action=session_replace",
        ] {
            assert!(
                message.contains(expected),
                "expected error to mention {expected:?}, got: {message}"
            );
        }
    }

    /// After appending to two distinct topics in one session, `session_list_topics` reports
    /// `count == 2`.
    #[tokio::test]
    async fn memory_session_list_topics_includes_count() {
        let dir = tempfile::tempdir().expect("tempdir");
        let tool = build_memory_tool(dir.path());

        for (topic, content) in [("alpha", "A"), ("beta", "B")] {
            tool.execute_with_context(
                json!({"action":"session_append","topic":topic,"content":content}),
                test_context("session-3"),
            )
            .await
            .expect("session append should succeed");
        }

        let list = tool
            .execute_with_context(
                json!({"action":"session_list_topics"}),
                test_context("session-3"),
            )
            .await
            .expect("session list topics should succeed");
        let value = parse_result(&list);
        assert_eq!(value["action"], "session_list_topics");
        assert_eq!(value["count"], 2);
    }
}