use bytes::Bytes;
use resp_async::response::RespError;
use resp_async::{Cmd, State, Value};
use sqlx::Row;
use std::cmp::Ordering;
use std::sync::Arc;
use crate::handlers::util::{arg_as_bytes, arg_as_f64, arg_as_i64, wrong_arity, wrong_type};
use crate::state::{AppState, SessionHandle, now_ms};
use crate::storage::{TYPE_ZSET, delete_key_in_tx, load_meta, map_sql_err};
/// Alphabet of the standard base-32 geohash encoding (digits plus lowercase
/// letters without `a`, `i`, `l`, `o`).
const BASE32: &[u8; 32] = b"0123456789bcdefghjkmnpqrstuvwxyz";
/// Number of bisection steps per axis. Each step contributes one longitude bit
/// and one latitude bit, so an encoded hash is `2 * GEO_STEP` = 52 bits wide.
const GEO_STEP: u8 = 26;
pub async fn geoadd(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() < 4 || (cmd.args.len() - 1) % 3 != 0 {
return Err(wrong_arity("GEOADD"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
if let Some(meta) = load_meta(auth.pool.as_ref(), key, now).await?
&& meta.r_type != TYPE_ZSET
{
return Ok(wrong_type());
}
let mut entries = Vec::new();
let mut i = 1;
while i < cmd.args.len() {
let lon = arg_as_f64(&cmd.args[i])?;
let lat = arg_as_f64(&cmd.args[i + 1])?;
let member = arg_as_bytes(&cmd.args[i + 2])?.clone();
let hash = encode_geohash(lon, lat)
.ok_or_else(|| RespError::invalid_data("ERR invalid longitude,latitude"))?;
entries.push((member, hash as f64));
i += 3;
}
let mut existing = 0i64;
if !entries.is_empty() {
let mut qb = sqlx::QueryBuilder::new("SELECT member FROM redis_zset WHERE r_key = ");
qb.push_bind(key.as_ref());
qb.push(" AND member IN (");
let mut separated = qb.separated(", ");
for (member, _) in &entries {
separated.push_bind(member.as_ref());
}
qb.push(")");
let rows = qb
.build()
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
existing = rows.len() as i64;
}
let added = entries.len() as i64 - existing;
let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
if added > 0 {
sqlx::query(
"INSERT INTO redis_kv (r_key, r_type, r_len, expires_at_ms) \
VALUES (?, ?, ?, NULL) \
ON DUPLICATE KEY UPDATE r_len = r_len + VALUES(r_len)",
)
.bind(key.as_ref())
.bind(TYPE_ZSET)
.bind(added)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
}
let mut qb = sqlx::QueryBuilder::new("INSERT INTO redis_zset (r_key, member, score) ");
qb.push_values(entries.iter(), |mut row, (member, score)| {
row.push_bind(key.as_ref());
row.push_bind(member.as_ref());
row.push_bind(*score);
});
qb.push(" ON DUPLICATE KEY UPDATE score = VALUES(score)");
qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
tx.commit().await.map_err(map_sql_err)?;
Ok(Value::Integer(added.max(0)))
}
pub async fn geopos(
Cmd(cmd): Cmd,
State(state): State<AppState>,
SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
if cmd.args.len() < 2 {
return Err(wrong_arity("GEOPOS"));
}
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let key = arg_as_bytes(&cmd.args[0])?;
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Array(vec![Value::Null; cmd.args.len() - 1]));
};
if meta.r_type != TYPE_ZSET {
return Ok(wrong_type());
}
let mut qb = sqlx::QueryBuilder::new("SELECT member, score FROM redis_zset WHERE r_key = ");
qb.push_bind(key.as_ref());
qb.push(" AND member IN (");
let mut separated = qb.separated(", ");
for arg in cmd.args.iter().skip(1) {
let member = arg_as_bytes(arg)?;
separated.push_bind(member.as_ref());
}
qb.push(")");
let rows = qb
.build()
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let mut map = std::collections::HashMap::new();
for row in rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
let score: f64 = row.try_get("score").map_err(map_sql_err)?;
map.insert(member, score);
}
let mut out = Vec::with_capacity(cmd.args.len() - 1);
for arg in cmd.args.iter().skip(1) {
let member = arg_as_bytes(arg)?;
match map.get(member.as_ref()) {
Some(score) => {
let (lon, lat) = decode_geohash(*score as u64);
out.push(Value::Array(vec![
Value::Bulk(Bytes::from(lon.to_string())),
Value::Bulk(Bytes::from(lat.to_string())),
]));
}
None => out.push(Value::Null),
}
}
Ok(Value::Array(out))
}
/// `GEODIST key member1 member2 [m|km|ft|mi]`
///
/// Replies with the great-circle distance between the two members' decoded
/// positions, converted to the requested unit (meters by default). The value
/// is formatted with exactly four decimals, as Redis does. Replies nil when
/// the key or either member is missing, and WRONGTYPE for a non-zset key.
///
/// # Errors
/// - wrong number of arguments (`wrong_arity`);
/// - no authenticated session (`RespError::NoAuth`);
/// - an unknown unit (checked before any database access);
/// - any SQL failure (via `map_sql_err`).
pub async fn geodist(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    if !(3..=4).contains(&cmd.args.len()) {
        return Err(wrong_arity("GEODIST"));
    }
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    let member1 = arg_as_bytes(&cmd.args[1])?;
    let member2 = arg_as_bytes(&cmd.args[2])?;
    let unit: &[u8] = match cmd.args.get(3) {
        Some(arg) => arg_as_bytes(arg)?.as_ref(),
        None => b"m",
    };
    // Reject an unknown unit up front instead of only when both members exist.
    from_meters(0.0, unit)?;

    let now = now_ms() as i64;
    let Some(meta) = load_meta(auth.pool.as_ref(), key, now).await? else {
        return Ok(Value::Null);
    };
    if meta.r_type != TYPE_ZSET {
        return Ok(wrong_type());
    }
    // Skip the second lookup entirely when the first member is absent.
    let Some((lon1, lat1)) = fetch_member_pos(auth.pool.as_ref(), key, member1).await? else {
        return Ok(Value::Null);
    };
    let Some((lon2, lat2)) = fetch_member_pos(auth.pool.as_ref(), key, member2).await? else {
        return Ok(Value::Null);
    };
    let dist = from_meters(haversine(lon1, lat1, lon2, lat2), unit)?;
    // Redis replies with four digits after the decimal point for distances.
    Ok(Value::Bulk(Bytes::from(format!("{dist:.4}"))))
}
/// `GEOHASH key member [member ...]`
///
/// Replies with one entry per requested member, in request order. A member
/// that is present yields the 11-character base-32 string produced by
/// `geohash_string` from its stored score. An absent member yields nil. A
/// missing key produces an all-nil array, and a key of another type produces
/// the WRONGTYPE reply.
pub async fn geohash(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let Some(members) = cmd.args.get(1..).filter(|rest| !rest.is_empty()) else {
        return Err(wrong_arity("GEOHASH"));
    };
    let auth = session.auth().await.ok_or(RespError::NoAuth)?;
    state.pools.touch(&auth.user);
    let key = arg_as_bytes(&cmd.args[0])?;
    match load_meta(auth.pool.as_ref(), key, now_ms() as i64).await? {
        None => return Ok(Value::Array(vec![Value::Null; members.len()])),
        Some(meta) if meta.r_type != TYPE_ZSET => return Ok(wrong_type()),
        Some(_) => {}
    }

    // Look up every requested member with one IN (...) query.
    let mut query = sqlx::QueryBuilder::new("SELECT member, score FROM redis_zset WHERE r_key = ");
    query.push_bind(key.as_ref());
    query.push(" AND member IN (");
    {
        let mut in_list = query.separated(", ");
        for arg in members {
            in_list.push_bind(arg_as_bytes(arg)?.as_ref());
        }
    }
    query.push(")");
    let rows = query
        .build()
        .fetch_all(auth.pool.as_ref())
        .await
        .map_err(map_sql_err)?;
    let mut scores: std::collections::HashMap<Vec<u8>, f64> =
        std::collections::HashMap::with_capacity(rows.len());
    for row in &rows {
        let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
        let score: f64 = row.try_get("score").map_err(map_sql_err)?;
        scores.insert(member, score);
    }

    let hashes = members
        .iter()
        .map(|arg| -> Result<Value, RespError> {
            let member = arg_as_bytes(arg)?;
            Ok(scores.get(member.as_ref()).map_or(Value::Null, |&score| {
                Value::Bulk(Bytes::from(geohash_string(score as u64)))
            }))
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(Value::Array(hashes))
}
/// `GEORADIUS key longitude latitude radius unit [options ...]`
///
/// Searches around an explicit coordinate pair; all parsing and searching is
/// delegated to `georadius_inner`.
pub async fn georadius(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let centered_on_member = false;
    georadius_inner(cmd, state, session, centered_on_member).await
}
/// `GEORADIUSBYMEMBER key member radius unit [options ...]`
///
/// Searches around the stored position of an existing member; all parsing and
/// searching is delegated to `georadius_inner`.
pub async fn georadiusbymember(
    Cmd(cmd): Cmd,
    State(state): State<AppState>,
    SessionHandle(session): SessionHandle,
) -> Result<Value, RespError> {
    let centered_on_member = true;
    georadius_inner(cmd, state, session, centered_on_member).await
}
async fn georadius_inner(
cmd: resp_async::Command,
state: Arc<AppState>,
session: Arc<crate::state::Session>,
by_member: bool,
) -> Result<Value, RespError> {
let auth = session.auth().await.ok_or(RespError::NoAuth)?;
state.pools.touch(&auth.user);
let mut i = 1;
let key = arg_as_bytes(&cmd.args[0])?;
let (lon, lat) = if by_member {
let member = arg_as_bytes(&cmd.args[1])?;
i += 1;
match fetch_member_pos(auth.pool.as_ref(), key, member).await? {
Some(pos) => pos,
None => return Ok(Value::Array(Vec::new())),
}
} else {
let lon = arg_as_f64(&cmd.args[1])?;
let lat = arg_as_f64(&cmd.args[2])?;
i += 2;
(lon, lat)
};
let radius = arg_as_f64(&cmd.args[i])?;
let unit = arg_as_bytes(&cmd.args[i + 1])?;
i += 2;
let mut with_coord = false;
let mut with_dist = false;
let mut with_hash = false;
let mut count: Option<usize> = None;
let mut order: Option<Ordering> = None;
let mut store_key: Option<Bytes> = None;
let mut store_dist = false;
while i < cmd.args.len() {
let mut token = arg_as_bytes(&cmd.args[i])?.to_vec();
for b in &mut token {
b.make_ascii_uppercase();
}
match token.as_slice() {
b"WITHCOORD" => with_coord = true,
b"WITHDIST" => with_dist = true,
b"WITHHASH" => with_hash = true,
b"COUNT" => {
let value = arg_as_i64(&cmd.args[i + 1])?;
if value > 0 {
count = Some(value as usize);
}
i += 1;
}
b"ASC" => order = Some(Ordering::Less),
b"DESC" => order = Some(Ordering::Greater),
b"STORE" => {
store_key = Some(arg_as_bytes(&cmd.args[i + 1])?.clone());
store_dist = false;
i += 1;
}
b"STOREDIST" => {
store_key = Some(arg_as_bytes(&cmd.args[i + 1])?.clone());
store_dist = true;
i += 1;
}
_ => return Err(RespError::invalid_data("ERR syntax error")),
}
i += 1;
}
let radius_m = to_meters(radius, unit.as_ref())?;
let now = now_ms() as i64;
let meta = load_meta(auth.pool.as_ref(), key, now).await?;
let Some(meta) = meta else {
return Ok(Value::Array(Vec::new()));
};
if meta.r_type != TYPE_ZSET {
return Ok(wrong_type());
}
let rows = sqlx::query("SELECT member, score FROM redis_zset WHERE r_key = ?")
.bind(key.as_ref())
.fetch_all(auth.pool.as_ref())
.await
.map_err(map_sql_err)?;
let mut results = Vec::new();
for row in rows {
let member: Vec<u8> = row.try_get("member").map_err(map_sql_err)?;
let score: f64 = row.try_get("score").map_err(map_sql_err)?;
let (mlon, mlat) = decode_geohash(score as u64);
let dist = haversine(lon, lat, mlon, mlat);
if dist <= radius_m {
results.push(ResultRow {
member: Bytes::from(member),
score,
distance: dist,
coord: (mlon, mlat),
});
}
}
if let Some(ordering) = order {
results.sort_by(|a, b| {
if ordering == Ordering::Less {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(Ordering::Equal)
} else {
b.distance
.partial_cmp(&a.distance)
.unwrap_or(Ordering::Equal)
}
});
}
if let Some(limit) = count {
results.truncate(limit);
}
if let Some(dest) = store_key {
let mut tx = auth.pool.begin().await.map_err(map_sql_err)?;
delete_key_in_tx(&mut tx, &dest).await?;
if !results.is_empty() {
sqlx::query(
"INSERT INTO redis_kv (r_key, r_type, r_len, expires_at_ms) VALUES (?, ?, ?, NULL)",
)
.bind(dest.as_ref())
.bind(TYPE_ZSET)
.bind(results.len() as i64)
.execute(&mut *tx)
.await
.map_err(map_sql_err)?;
let mut qb = sqlx::QueryBuilder::new("INSERT INTO redis_zset (r_key, member, score) ");
qb.push_values(results.iter(), |mut row, item| {
row.push_bind(dest.as_ref());
row.push_bind(item.member.as_ref());
let score = if store_dist {
item.distance
} else {
item.score
};
row.push_bind(score);
});
qb.build().execute(&mut *tx).await.map_err(map_sql_err)?;
}
tx.commit().await.map_err(map_sql_err)?;
return Ok(Value::Integer(results.len() as i64));
}
let mut out = Vec::new();
for item in results {
if with_coord || with_dist || with_hash {
let mut entry = vec![Value::Bulk(item.member.clone())];
if with_dist {
entry.push(Value::Bulk(Bytes::from(
from_meters(item.distance, unit.as_ref())?.to_string(),
)));
}
if with_hash {
entry.push(Value::Bulk(Bytes::from((item.score as u64).to_string())));
}
if with_coord {
entry.push(Value::Array(vec![
Value::Bulk(Bytes::from(item.coord.0.to_string())),
Value::Bulk(Bytes::from(item.coord.1.to_string())),
]));
}
out.push(Value::Array(entry));
} else {
out.push(Value::Bulk(item.member));
}
}
Ok(Value::Array(out))
}
/// One member matched by a GEORADIUS-style search.
struct ResultRow {
    /// Sorted-set member name.
    member: Bytes,
    /// Distance from the search center in meters, as returned by `haversine`.
    distance: f64,
    /// `(longitude, latitude)` decoded from `score`.
    coord: (f64, f64),
    /// Raw stored score (GEOADD writes the interleaved geohash here).
    score: f64,
}
/// Looks up `member` in the sorted set stored at `key` and decodes its score
/// into `(longitude, latitude)` with `decode_geohash`.
///
/// Returns `Ok(None)` when no row exists for the member. The key's type is not
/// checked here.
async fn fetch_member_pos(
    pool: &sqlx::MySqlPool,
    key: &Bytes,
    member: &Bytes,
) -> Result<Option<(f64, f64)>, RespError> {
    let found = sqlx::query("SELECT score FROM redis_zset WHERE r_key = ? AND member = ?")
        .bind(key.as_ref())
        .bind(member.as_ref())
        .fetch_optional(pool)
        .await
        .map_err(map_sql_err)?;
    found
        .map(|row| {
            row.try_get::<f64, _>("score")
                .map(|score| decode_geohash(score as u64))
                .map_err(map_sql_err)
        })
        .transpose()
}
/// Encodes a longitude/latitude pair (degrees) into a `2 * GEO_STEP`-bit
/// interleaved geohash.
///
/// Each step halves the current longitude interval and then the latitude
/// interval. It emits a 1 bit when the value falls in the upper half
/// (`>= midpoint`) and a 0 bit otherwise, with the longitude bit before the
/// latitude bit. Returns `None` when longitude is outside `[-180, 180]` or
/// latitude is outside `[-90, 90]`, which includes NaN.
fn encode_geohash(lon: f64, lat: f64) -> Option<u64> {
    let lon_ok = (-180.0..=180.0).contains(&lon);
    let lat_ok = (-90.0..=90.0).contains(&lat);
    if !(lon_ok && lat_ok) {
        return None;
    }

    // Narrows `range` to the half containing `value` and reports which half it was.
    fn bisect(value: f64, range: &mut (f64, f64)) -> u64 {
        let mid = (range.0 + range.1) / 2.0;
        if value >= mid {
            range.0 = mid;
            1
        } else {
            range.1 = mid;
            0
        }
    }

    let mut lon_range = (-180.0_f64, 180.0_f64);
    let mut lat_range = (-90.0_f64, 90.0_f64);
    let hash = (0..GEO_STEP).fold(0u64, |acc, _| {
        let acc = (acc << 1) | bisect(lon, &mut lon_range);
        (acc << 1) | bisect(lat, &mut lat_range)
    });
    Some(hash)
}
/// Decodes a `2 * GEO_STEP`-bit interleaved geohash back to the center of its
/// cell as `(longitude, latitude)` in degrees.
///
/// Bits are consumed from the most significant pair down. In each pair the
/// higher bit refines longitude and the lower bit refines latitude: a 1 keeps
/// the upper half of the interval and a 0 keeps the lower half.
fn decode_geohash(hash: u64) -> (f64, f64) {
    // Keeps the upper half of `range` when `bit` is 1, the lower half otherwise.
    fn narrow(range: &mut (f64, f64), bit: u64) {
        let mid = (range.0 + range.1) / 2.0;
        if bit == 1 {
            range.0 = mid;
        } else {
            range.1 = mid;
        }
    }

    let mut lon = (-180.0_f64, 180.0_f64);
    let mut lat = (-90.0_f64, 90.0_f64);
    for step in (0..u32::from(GEO_STEP)).rev() {
        let shift = step * 2;
        narrow(&mut lon, (hash >> (shift + 1)) & 1);
        narrow(&mut lat, (hash >> shift) & 1);
    }
    ((lon.0 + lon.1) / 2.0, (lat.0 + lat.1) / 2.0)
}
/// Renders a 52-bit interleaved hash as an 11-character base-32 geohash string.
///
/// The hash is shifted left by 3 bits to fill 55 bits (11 groups of 5), and
/// each group is emitted most significant first as one `BASE32` character.
fn geohash_string(hash: u64) -> String {
    let padded = hash << 3;
    (0..11u32)
        .rev()
        .map(|group| BASE32[((padded >> (group * 5)) & 0x1f) as usize] as char)
        .collect()
}
/// Great-circle distance in meters between two `(longitude, latitude)` points
/// given in degrees. Uses the haversine formula on a sphere of radius
/// 6372797.6 m.
///
/// The haversine term `a` is clamped to `[0, 1]`. For (near-)antipodal points,
/// floating-point rounding may push it a hair above 1. `sqrt(1 - a)` would then
/// be NaN, and a NaN distance silently fails every radius comparison and breaks
/// sorting by distance.
fn haversine(lon1: f64, lat1: f64, lon2: f64, lat2: f64) -> f64 {
    const EARTH_RADIUS_M: f64 = 6372797.6;
    let half_dlon = (lon2 - lon1).to_radians() / 2.0;
    let half_dlat = (lat2 - lat1).to_radians() / 2.0;
    let a = half_dlat.sin().powi(2)
        + lat1.to_radians().cos() * lat2.to_radians().cos() * half_dlon.sin().powi(2);
    let a = a.clamp(0.0, 1.0);
    let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());
    EARTH_RADIUS_M * c
}
/// Number of meters in one `unit`. Accepts `m`, `km`, `mi`, `ft`
/// case-insensitively, as Redis does.
///
/// # Errors
/// `ERR invalid unit` for any other unit name.
fn meters_per_unit(unit: &[u8]) -> Result<f64, RespError> {
    match unit.to_ascii_lowercase().as_slice() {
        b"m" => Ok(1.0),
        b"km" => Ok(1000.0),
        b"mi" => Ok(1609.34),
        b"ft" => Ok(0.3048),
        _ => Err(RespError::invalid_data("ERR invalid unit")),
    }
}
/// Converts a distance in meters to `unit` (see `meters_per_unit`).
fn from_meters(value: f64, unit: &[u8]) -> Result<f64, RespError> {
    Ok(value / meters_per_unit(unit)?)
}
/// Converts a distance expressed in `unit` to meters (see `meters_per_unit`).
fn to_meters(value: f64, unit: &[u8]) -> Result<f64, RespError> {
    Ok(value * meters_per_unit(unit)?)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn geohash_roundtrip_origin() {
        let hash = encode_geohash(0.0, 0.0).unwrap();
        let (lon, lat) = decode_geohash(hash);
        // One cell at GEO_STEP = 26 is ~5.4e-6 degrees wide in longitude.
        assert!(lon.abs() < 1e-5);
        assert!(lat.abs() < 1e-5);
    }

    #[test]
    fn geohash_roundtrip_arbitrary_point() {
        let (lon, lat) = (13.361389, 38.115556);
        let (dlon, dlat) = decode_geohash(encode_geohash(lon, lat).unwrap());
        assert!((dlon - lon).abs() < 1e-5);
        assert!((dlat - lat).abs() < 1e-5);
    }

    #[test]
    fn encode_rejects_out_of_range() {
        assert_eq!(encode_geohash(180.5, 0.0), None);
        assert_eq!(encode_geohash(0.0, -90.5), None);
        assert_eq!(encode_geohash(f64::NAN, 0.0), None);
        assert_eq!(encode_geohash(-180.0, -90.0), Some(0));
        assert_eq!(encode_geohash(180.0, 90.0), Some((1u64 << 52) - 1));
    }

    #[test]
    fn geohash_string_matches_standard_geohash() {
        assert_eq!(geohash_string(encode_geohash(0.0, 0.0).unwrap()), "s0000000000");
    }

    #[test]
    fn haversine_is_zero_for_identical_points() {
        assert_eq!(haversine(1.5, 2.5, 1.5, 2.5), 0.0);
    }

    #[test]
    fn unit_conversions() {
        assert_eq!(to_meters(2.0, b"km").ok(), Some(2000.0));
        assert_eq!(from_meters(1000.0, b"km").ok(), Some(1.0));
        assert_eq!(to_meters(5.0, b"m").ok(), Some(5.0));
        assert!(from_meters(1.0, b"parsec").is_err());
    }
}