Skip to main content

mcp_postgres/actions/
pgvector.rs

1use crate::errors::Result as MCPResult;
2use serde_json::{Value, json};
3use tokio_postgres::Client;
4
5pub async fn list_vector_columns(client: &Client, _params: &Option<&Value>) -> MCPResult<Value> {
6    let rows = client
7        .query(
8            "SELECT c.table_schema, c.table_name, c.column_name, c.data_type,
9                    e.udt_name
10             FROM information_schema.columns c
11             JOIN information_schema.element_types e ON (c.table_catalog, c.table_schema, c.table_name, c.column_name, c.dtd_identifier)
12             WHERE c.data_type = 'USER-DEFINED'
13               AND e.udt_name = 'vector'
14             ORDER BY c.table_schema, c.table_name, c.ordinal_position",
15            &[],
16        )
17        .await
18        ?;
19
20    let columns: Vec<Value> = rows
21        .iter()
22        .map(|row| {
23            json!({
24                "schema": row.get::<_, String>(0),
25                "table": row.get::<_, String>(1),
26                "column": row.get::<_, String>(2),
27                "type": "vector",
28            })
29        })
30        .collect();
31
32    Ok(json!({ "vector_columns": columns }))
33}
34
35pub async fn vector_search(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
36    let table = params
37        .as_ref()
38        .and_then(|p| p.get("table").and_then(|v| v.as_str()))
39        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
40    let column = params
41        .as_ref()
42        .and_then(|p| p.get("column").and_then(|v| v.as_str()))
43        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'column'".into()))?;
44    let vector = params
45        .as_ref()
46        .and_then(|p| p.get("vector").and_then(|v| v.as_str()))
47        .ok_or_else(|| {
48            crate::errors::MCPError::InvalidParams(
49                "Missing 'vector' parameter (e.g. '[0.1,0.2,0.3]')".into(),
50            )
51        })?;
52    let limit = params
53        .as_ref()
54        .and_then(|p| p.get("limit").and_then(|v| v.as_i64()))
55        .unwrap_or(10);
56    let schema = params
57        .as_ref()
58        .and_then(|p| p.get("schema").and_then(|v| v.as_str()))
59        .unwrap_or("public");
60    let select_cols = params
61        .as_ref()
62        .and_then(|p| p.get("select").and_then(|v| v.as_str()))
63        .unwrap_or("*");
64    let distance = params
65        .as_ref()
66        .and_then(|p| p.get("distance").and_then(|v| v.as_str()))
67        .unwrap_or("cosine");
68
69    let operator = match distance {
70        "l2" | "euclidean" => "<->",
71        "inner" | "ip" => "<#>",
72        _ => "<=>",
73    };
74
75    let qcol = crate::validation::quote_ident(column);
76    let qual = format!(
77        "{}.{}",
78        crate::validation::quote_ident(schema),
79        crate::validation::quote_ident(table)
80    );
81    let sql = format!(
82        "SELECT {}, {qcol} {operator} '{vector}' AS distance
83         FROM {qual}
84         ORDER BY {qcol} {operator} '{vector}'
85         LIMIT {}",
86        select_cols,
87        limit.min(1000)
88    );
89
90    let rows = client.query(&sql, &[]).await?;
91
92    let mut results = Vec::new();
93    for row in &rows {
94        let mut obj = serde_json::Map::new();
95        for (i, col) in row.columns().iter().enumerate() {
96            let name = col.name();
97            if let Ok(v) = row.try_get::<_, Value>(i) {
98                obj.insert(name.to_string(), v);
99            } else if let Ok(v) = row.try_get::<_, String>(i) {
100                obj.insert(name.to_string(), Value::String(v));
101            } else if let Ok(v) = row.try_get::<_, i64>(i) {
102                obj.insert(name.to_string(), json!(v));
103            } else if let Ok(v) = row.try_get::<_, f64>(i) {
104                obj.insert(name.to_string(), json!(v));
105            } else if let Ok(v) = row.try_get::<_, bool>(i) {
106                obj.insert(name.to_string(), json!(v));
107            } else if let Ok(v) = row.try_get::<_, Option<String>>(i) {
108                obj.insert(
109                    name.to_string(),
110                    v.map(Value::String).unwrap_or(Value::Null),
111                );
112            }
113        }
114        results.push(Value::Object(obj));
115    }
116
117    Ok(json!({
118        "results": results,
119        "count": results.len(),
120        "distance_metric": distance,
121    }))
122}
123
124pub async fn create_vector_index(client: &Client, params: &Option<&Value>) -> MCPResult<Value> {
125    let table = params
126        .as_ref()
127        .and_then(|p| p.get("table").and_then(|v| v.as_str()))
128        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'table'".into()))?;
129    let column = params
130        .as_ref()
131        .and_then(|p| p.get("column").and_then(|v| v.as_str()))
132        .ok_or_else(|| crate::errors::MCPError::InvalidParams("Missing 'column'".into()))?;
133    let index_type = params
134        .as_ref()
135        .and_then(|p| p.get("index_type").and_then(|v| v.as_str()))
136        .unwrap_or("hnsw");
137    let distance = params
138        .as_ref()
139        .and_then(|p| p.get("distance").and_then(|v| v.as_str()))
140        .unwrap_or("cosine");
141    let schema = params
142        .as_ref()
143        .and_then(|p| p.get("schema").and_then(|v| v.as_str()))
144        .unwrap_or("public");
145
146    let distance_op = match distance {
147        "l2" | "euclidean" => "vector_l2_ops",
148        "inner" | "ip" => "vector_ip_ops",
149        _ => "vector_cosine_ops",
150    };
151
152    let index_name = format!("idx_{}_{}_{}", table, column, index_type);
153
154    let q_schema = crate::validation::quote_ident(schema);
155    let q_table = crate::validation::quote_ident(table);
156    let q_column = crate::validation::quote_ident(column);
157    let sql = match index_type {
158        "ivfflat" => {
159            let lists = params
160                .as_ref()
161                .and_then(|p| p.get("lists").and_then(|v| v.as_i64()))
162                .unwrap_or(100);
163            format!(
164                "CREATE INDEX \"{index_name}\" ON {q_schema}.{q_table} USING ivfflat ({q_column} {distance_op}) WITH (lists = {lists})"
165            )
166        }
167        _ => {
168            format!(
169                "CREATE INDEX \"{index_name}\" ON {q_schema}.{q_table} USING hnsw ({q_column} {distance_op})"
170            )
171        }
172    };
173
174    client.execute(&sql, &[]).await?;
175    Ok(json!({ "success": true, "index": index_name, "sql": sql }))
176}