1use std::io::{self, BufReader};
2use std::path::PathBuf;
3use std::str::FromStr;
4
5use anyhow::{Context, Result, bail};
6use clap::{Parser, Subcommand};
7use serde::Serialize;
8
9use crate::embedding::configure_embeddings_path;
10use crate::explorer;
11use crate::mcp;
12use crate::model::{ExpirationCondition, MemoryMode};
13use crate::store::{MemoryStore, SearchOptions, SetMemory, default_database_path};
14
// Top-level command line. `--db` and `--embeddings` are `global`, so clap accepts them
// before or after the subcommand, and each may instead be supplied through the
// environment variable named in its `env = ...` attribute. Dispatch happens in `run`.
#[derive(Debug, Parser)]
#[command(version, about = "A smart memory management system for agents")]
pub struct Cli {
    // Database file; when absent, `run` falls back to `default_database_path()`.
    #[arg(long, global = true, env = "MII_MEMORY_DB", value_name = "PATH")]
    db: Option<PathBuf>,

    // Embeddings path; when present, `run` hands it to `configure_embeddings_path`
    // before the store is opened.
    #[arg(
        long,
        global = true,
        env = "MII_MEMORY_EMBEDDINGS",
        value_name = "PATH"
    )]
    embeddings: Option<PathBuf>,

    #[command(subcommand)]
    command: Command,
}
32
33#[derive(Debug, Subcommand)]
34enum Command {
35 Set(SetCommand),
36 Get(GetCommand),
37 ListTags(ListTagsCommand),
38 Alert(AlertCommand),
39 Alerts(AlertsCommand),
40 Mcp,
41 Explorer(ExplorerCommand),
42}
43
// `alert` subcommand group; the concrete action is selected by `AlertSubcommand`.
#[derive(Debug, Parser)]
struct AlertCommand {
    #[command(subcommand)]
    command: AlertSubcommand,
}
49
// Actions available under `alert`. Only `set` exists so far; `run` matches on this
// enum without a wildcard arm, so any new variant must be handled there.
#[derive(Debug, Subcommand)]
enum AlertSubcommand {
    Set(AlertSetCommand),
}
54
// Positional arguments of `alert set SESSION_REF CONTENT`; `run` forwards both to
// `MemoryStore::set_alert` and prints the returned id as JSON.
#[derive(Debug, Parser)]
struct AlertSetCommand {
    session_ref: String,
    content: String,
}
60
// Positional argument of `alerts SESSION_REF`; `run` prints each alert returned by
// `MemoryStore::get_alerts` as one JSON line.
#[derive(Debug, Parser)]
struct AlertsCommand {
    session_ref: String,
}
65
// Arguments of `set`; converted into a `SetMemory` by
// `impl TryFrom<SetCommand> for SetMemory`.
#[derive(Debug, Parser)]
struct SetCommand {
    // Memory text (first positional argument).
    content: String,

    // Storage mode; defaults to `MemoryMode::Global`.
    #[arg(long, value_enum, default_value_t = MemoryMode::Global)]
    mode: MemoryMode,

    // Optional second positional argument.
    // NOTE(review): `get` accepts the same value as the flag `--mode-ref`, so the two
    // subcommands spell it differently.
    #[arg(value_name = "MODE_REF")]
    mode_ref: Option<String>,

    // `-t/--tag` values; at least one is required.
    #[arg(short = 't', long = "tag", required = true)]
    tags: Vec<String>,

    // `--expiration-condition CONDITION VALUE`: two raw strings, validated later by
    // `parse_expiration_pair`. NOTE(review): repeating the flag presumably appends more
    // values (clap's default for Vec fields), which that function then rejects — confirm.
    #[arg(
        long = "expiration-condition",
        value_names = ["CONDITION", "VALUE"],
        num_args = 2
    )]
    expiration: Option<Vec<String>>,

    // Optional metadata string, passed through to `SetMemory` without validation here.
    #[arg(long)]
    metadata: Option<String>,
}
89
// Arguments of `get`; mapped field-for-field onto `SearchOptions`.
#[derive(Debug, Parser)]
struct GetCommand {
    // Query text (positional).
    query: String,

    // Tags given with `-t/--tag` (alias `--p-tag`); presumably tags matches must carry.
    #[arg(short = 't', long = "tag", alias = "p-tag")]
    positive_tags: Vec<String>,

    // Tags given with `--n-tag`; presumably tags that exclude a match.
    #[arg(long = "n-tag")]
    negative_tags: Vec<String>,

    // Forwarded as `SearchOptions::limit`; defaults to 10.
    #[arg(long, default_value_t = 10)]
    limit: usize,

    // Forwarded as `SearchOptions::offset`; defaults to 0.
    #[arg(long, default_value_t = 0)]
    offset: usize,

    // Optional mode filter; unlike `set`, no default mode is applied.
    #[arg(long, value_enum)]
    mode: Option<MemoryMode>,

    // Given as the flag `--mode-ref` (`set` takes this value positionally).
    #[arg(long)]
    mode_ref: Option<String>,
}
112
// Arguments of `list-tags`.
#[derive(Debug, Parser)]
struct ListTagsCommand {
    // Optional filter string, forwarded to `MemoryStore::list_tags`.
    #[arg(long)]
    filter: Option<String>,

    // When set, print each tag record as a JSON line; otherwise print only its `tag` field.
    #[arg(long)]
    json: bool,
}
121
// Arguments of `explorer`; `run` passes both values to `explorer::serve`.
#[derive(Debug, Parser)]
struct ExplorerCommand {
    // Host handed to the explorer, presumably its bind address; the 127.0.0.1
    // default would then restrict access to the local machine — confirm in `explorer`.
    #[arg(long, default_value = "127.0.0.1")]
    host: String,

    // Port handed to the explorer; defaults to 4117.
    #[arg(long, default_value_t = 4117)]
    port: u16,
}
130
/// JSON payload printed by `set` and `alert set`, serialized as `{"id":<id>}`.
#[derive(Debug, Serialize)]
struct SetOutput {
    /// Id returned by the store for the record that was just created.
    id: i64,
}
135
136pub fn run() -> Result<()> {
137 let cli = Cli::parse();
138 if let Some(embeddings_path) = cli.embeddings {
139 configure_embeddings_path(embeddings_path)?;
140 }
141
142 let database_path = cli.db.unwrap_or_else(default_database_path);
143 let mut store = MemoryStore::open(&database_path)?;
144
145 match cli.command {
146 Command::Set(command) => {
147 let id = store.set(command.try_into()?)?;
148 println!("{}", serde_json::to_string(&SetOutput { id })?);
149 }
150 Command::Get(command) => {
151 for result in store.get(command.into())? {
152 println!("{}", serde_json::to_string(&result)?);
153 }
154 }
155 Command::ListTags(command) => {
156 let tags = store.list_tags(command.filter.as_deref())?;
157 for tag in tags {
158 if command.json {
159 println!("{}", serde_json::to_string(&tag)?);
160 } else {
161 println!("{}", tag.tag);
162 }
163 }
164 }
165 Command::Alert(command) => match command.command {
166 AlertSubcommand::Set(command) => {
167 let id = store.set_alert(command.session_ref, command.content)?;
168 println!("{}", serde_json::to_string(&SetOutput { id })?);
169 }
170 },
171 Command::Alerts(command) => {
172 for alert in store.get_alerts(command.session_ref)? {
173 println!("{}", serde_json::to_string(&alert)?);
174 }
175 }
176 Command::Mcp => {
177 let input = BufReader::new(io::stdin().lock());
178 let output = io::stdout().lock();
179 mcp::serve(store, input, output)?;
180 }
181 Command::Explorer(command) => {
182 drop(store);
183 explorer::serve(database_path, &command.host, command.port)?;
184 }
185 }
186
187 Ok(())
188}
189
190impl TryFrom<SetCommand> for SetMemory {
191 type Error = anyhow::Error;
192
193 fn try_from(command: SetCommand) -> Result<Self> {
194 let (expiration_condition, expiration_value) = parse_expiration_pair(command.expiration)?;
195
196 Ok(Self {
197 content: command.content,
198 mode: command.mode,
199 mode_ref: command.mode_ref,
200 tags: command.tags,
201 expiration_condition,
202 expiration_value,
203 metadata: command.metadata,
204 })
205 }
206}
207
208impl From<GetCommand> for SearchOptions {
209 fn from(command: GetCommand) -> Self {
210 Self {
211 query: command.query,
212 positive_tags: command.positive_tags,
213 negative_tags: command.negative_tags,
214 limit: command.limit,
215 offset: command.offset,
216 mode: command.mode,
217 mode_ref: command.mode_ref,
218 }
219 }
220}
221
222pub fn parse_expiration_pair(
223 expiration: Option<Vec<String>>,
224) -> Result<(Option<ExpirationCondition>, Option<String>)> {
225 let Some(expiration) = expiration else {
226 return Ok((None, None));
227 };
228
229 let [condition, value] = expiration.as_slice() else {
230 bail!("--expiration-condition expects CONDITION and VALUE");
231 };
232
233 Ok((
234 Some(
235 ExpirationCondition::from_str(condition)
236 .with_context(|| format!("invalid expiration condition {condition}"))?,
237 ),
238 Some(value.clone()),
239 ))
240}