Skip to main content

mii_memory/
cli.rs

1use std::io::{self, BufReader};
2use std::path::PathBuf;
3use std::str::FromStr;
4
5use anyhow::{Context, Result, bail};
6use clap::{Parser, Subcommand};
7use serde::Serialize;
8
9use crate::embedding::configure_embeddings_path;
10use crate::explorer;
11use crate::mcp;
12use crate::model::{ExpirationCondition, MemoryMode};
13use crate::store::{MemoryStore, SearchOptions, SetMemory, default_database_path};
14
15#[derive(Debug, Parser)]
16#[command(version, about = "A smart memory management system for agents")]
17pub struct Cli {
18    #[arg(long, global = true, env = "MII_MEMORY_DB", value_name = "PATH")]
19    db: Option<PathBuf>,
20
21    #[arg(
22        long,
23        global = true,
24        env = "MII_MEMORY_EMBEDDINGS",
25        value_name = "PATH"
26    )]
27    embeddings: Option<PathBuf>,
28
29    #[command(subcommand)]
30    command: Command,
31}
32
33#[derive(Debug, Subcommand)]
34enum Command {
35    Set(SetCommand),
36    Get(GetCommand),
37    ListTags(ListTagsCommand),
38    Alert(AlertCommand),
39    Alerts(AlertsCommand),
40    Mcp,
41    Explorer(ExplorerCommand),
42}
43
44#[derive(Debug, Parser)]
45struct AlertCommand {
46    #[command(subcommand)]
47    command: AlertSubcommand,
48}
49
50#[derive(Debug, Subcommand)]
51enum AlertSubcommand {
52    Set(AlertSetCommand),
53}
54
55#[derive(Debug, Parser)]
56struct AlertSetCommand {
57    session_ref: String,
58    content: String,
59}
60
61#[derive(Debug, Parser)]
62struct AlertsCommand {
63    session_ref: String,
64}
65
66#[derive(Debug, Parser)]
67struct SetCommand {
68    content: String,
69
70    #[arg(long, value_enum, default_value_t = MemoryMode::Global)]
71    mode: MemoryMode,
72
73    #[arg(value_name = "MODE_REF")]
74    mode_ref: Option<String>,
75
76    #[arg(short = 't', long = "tag", required = true)]
77    tags: Vec<String>,
78
79    #[arg(
80        long = "expiration-condition",
81        value_names = ["CONDITION", "VALUE"],
82        num_args = 2
83    )]
84    expiration: Option<Vec<String>>,
85
86    #[arg(long)]
87    metadata: Option<String>,
88}
89
90#[derive(Debug, Parser)]
91struct GetCommand {
92    query: String,
93
94    #[arg(short = 't', long = "tag", alias = "p-tag")]
95    positive_tags: Vec<String>,
96
97    #[arg(long = "n-tag")]
98    negative_tags: Vec<String>,
99
100    #[arg(long, default_value_t = 10)]
101    limit: usize,
102
103    #[arg(long, default_value_t = 0)]
104    offset: usize,
105
106    #[arg(long, value_enum)]
107    mode: Option<MemoryMode>,
108
109    #[arg(long)]
110    mode_ref: Option<String>,
111}
112
113#[derive(Debug, Parser)]
114struct ListTagsCommand {
115    #[arg(long)]
116    filter: Option<String>,
117
118    #[arg(long)]
119    json: bool,
120}
121
122#[derive(Debug, Parser)]
123struct ExplorerCommand {
124    #[arg(long, default_value = "127.0.0.1")]
125    host: String,
126
127    #[arg(long, default_value_t = 4117)]
128    port: u16,
129}
130
131#[derive(Debug, Serialize)]
132struct SetOutput {
133    id: i64,
134}
135
136pub fn run() -> Result<()> {
137    let cli = Cli::parse();
138    if let Some(embeddings_path) = cli.embeddings {
139        configure_embeddings_path(embeddings_path)?;
140    }
141
142    let database_path = cli.db.unwrap_or_else(default_database_path);
143    let mut store = MemoryStore::open(&database_path)?;
144
145    match cli.command {
146        Command::Set(command) => {
147            let id = store.set(command.try_into()?)?;
148            println!("{}", serde_json::to_string(&SetOutput { id })?);
149        }
150        Command::Get(command) => {
151            for result in store.get(command.into())? {
152                println!("{}", serde_json::to_string(&result)?);
153            }
154        }
155        Command::ListTags(command) => {
156            let tags = store.list_tags(command.filter.as_deref())?;
157            for tag in tags {
158                if command.json {
159                    println!("{}", serde_json::to_string(&tag)?);
160                } else {
161                    println!("{}", tag.tag);
162                }
163            }
164        }
165        Command::Alert(command) => match command.command {
166            AlertSubcommand::Set(command) => {
167                let id = store.set_alert(command.session_ref, command.content)?;
168                println!("{}", serde_json::to_string(&SetOutput { id })?);
169            }
170        },
171        Command::Alerts(command) => {
172            for alert in store.get_alerts(command.session_ref)? {
173                println!("{}", serde_json::to_string(&alert)?);
174            }
175        }
176        Command::Mcp => {
177            let input = BufReader::new(io::stdin().lock());
178            let output = io::stdout().lock();
179            mcp::serve(store, input, output)?;
180        }
181        Command::Explorer(command) => {
182            drop(store);
183            explorer::serve(database_path, &command.host, command.port)?;
184        }
185    }
186
187    Ok(())
188}
189
190impl TryFrom<SetCommand> for SetMemory {
191    type Error = anyhow::Error;
192
193    fn try_from(command: SetCommand) -> Result<Self> {
194        let (expiration_condition, expiration_value) = parse_expiration_pair(command.expiration)?;
195
196        Ok(Self {
197            content: command.content,
198            mode: command.mode,
199            mode_ref: command.mode_ref,
200            tags: command.tags,
201            expiration_condition,
202            expiration_value,
203            metadata: command.metadata,
204        })
205    }
206}
207
208impl From<GetCommand> for SearchOptions {
209    fn from(command: GetCommand) -> Self {
210        Self {
211            query: command.query,
212            positive_tags: command.positive_tags,
213            negative_tags: command.negative_tags,
214            limit: command.limit,
215            offset: command.offset,
216            mode: command.mode,
217            mode_ref: command.mode_ref,
218        }
219    }
220}
221
222pub fn parse_expiration_pair(
223    expiration: Option<Vec<String>>,
224) -> Result<(Option<ExpirationCondition>, Option<String>)> {
225    let Some(expiration) = expiration else {
226        return Ok((None, None));
227    };
228
229    let [condition, value] = expiration.as_slice() else {
230        bail!("--expiration-condition expects CONDITION and VALUE");
231    };
232
233    Ok((
234        Some(
235            ExpirationCondition::from_str(condition)
236                .with_context(|| format!("invalid expiration condition {condition}"))?,
237        ),
238        Some(value.clone()),
239    ))
240}