1use std::path::PathBuf;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Duration;
18
19use axum::extract::{Path, Request, State};
20use axum::http::StatusCode;
21use axum::http::header::AUTHORIZATION;
22use axum::middleware::{self, Next};
23use axum::response::Response;
24use axum::routing::{get, post};
25use axum::{Extension, Json, Router, response::IntoResponse};
26use quiver_cluster::ShardMap;
27use serde::{Deserialize, Serialize};
28use serde_json::{Value, json};
29use tokio::net::TcpListener;
30use tokio::sync::RwLock;
31
32use crate::Config;
33use crate::auth::{self, Action, ApiKey, Principal};
34use crate::error::Error;
35
/// Pause used by `run_migration` twice: once before it starts copying points to
/// a joining shard, and once after promoting it, before the donors' copies are
/// deleted.
const MIGRATION_GRACE: Duration = Duration::from_secs(3);
/// Page size (`limit`) requested when paging through a donor's points, and the
/// maximum number of ids sent in a single delete request.
const COPY_PAGE: usize = 1_000;
43
/// Settings for the coordinator's automatic scale-out loop.
///
/// `#[serde(default)]` lets a configuration omit any subset of the fields; the
/// omitted ones take the values produced by the `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AutoscaleConfig {
    /// Master switch: the background autoscale task is only spawned, and a
    /// scale-out check only proceeds, when this is `true`.
    pub enabled: bool,
    /// A scale-out is triggered when the largest per-shard point count is
    /// strictly greater than this value. `0` disables the check.
    pub high_water_points: u64,
    /// URLs of standby nodes that a scale-out may turn into new shards.
    pub standby_urls: Vec<String>,
    /// Seconds between two autoscale checks (values below 1 are raised to 1 by
    /// the serving loop).
    pub interval_secs: u64,
    /// Minimum number of seconds between a successful scale-out and the next
    /// attempt.
    pub cooldown_secs: u64,
    /// Upper bound on the number of active shards; no scale-out happens once it
    /// is reached. `0` means no bound.
    pub max_shards: usize,
}
71
72impl Default for AutoscaleConfig {
73 fn default() -> Self {
74 Self {
75 enabled: false,
76 high_water_points: 0,
77 standby_urls: Vec::new(),
78 interval_secs: 30,
79 cooldown_secs: 300,
80 max_shards: 0,
81 }
82 }
83}
84
/// On-disk form of the coordinator state, serialized as JSON by
/// `CoordinatorState::persist` and read back by `CoordinatorState::bootstrap`.
#[derive(Serialize, Deserialize)]
struct Persisted {
    /// Value of the shard-id counter at the time of the write.
    next_id: u64,
    /// The shard map at the time of the write.
    map: ShardMap,
}
92
/// Shared state of the coordinator, held in an `Arc` by the HTTP handlers and
/// the background tasks.
struct CoordinatorState {
    /// The current shard map.
    map: RwLock<ShardMap>,
    /// Counter the next shard id is taken from (`fetch_add`).
    next_id: AtomicU64,
    /// File the state is persisted to; `None` makes `persist` a no-op.
    path: Option<PathBuf>,
    /// HTTP client used for every request the coordinator sends to a shard.
    http: reqwest::Client,
    /// Bearer token attached to outgoing shard requests, when configured.
    shard_key: Option<String>,
    /// API keys that incoming requests are authenticated against.
    keys: Arc<Vec<ApiKey>>,
    /// The autoscale policy read by `maybe_scale_out`.
    autoscale: AutoscaleConfig,
    /// Standby URLs still available to autoscaling; one is popped per
    /// scale-out and pushed back if the grow call fails.
    standby: tokio::sync::Mutex<Vec<String>>,
    /// When the last successful autoscale grow was started; `None` until the
    /// first one.
    last_scale: tokio::sync::Mutex<Option<std::time::Instant>>,
}
117
118impl CoordinatorState {
119 fn bootstrap(config: &Config) -> Result<Self, Error> {
122 let path = config.coordinator_state.clone();
123 if let Some(p) = &path
124 && p.exists()
125 {
126 let bytes = std::fs::read(p).map_err(Error::Io)?;
127 let persisted: Persisted = serde_json::from_slice(&bytes)
128 .map_err(|e| Error::Config(format!("coordinator state {p:?}: {e}")))?;
129 return Ok(Self {
130 map: RwLock::new(persisted.map),
131 next_id: AtomicU64::new(persisted.next_id),
132 path,
133 http: reqwest::Client::new(),
134 shard_key: config.cluster_shard_key.clone(),
135 keys: Arc::new(config.api_keys.clone()),
136 autoscale: config.autoscale.clone(),
137 standby: tokio::sync::Mutex::new(config.autoscale.standby_urls.clone()),
138 last_scale: tokio::sync::Mutex::new(None),
139 });
140 }
141 let map = build_seed_map(config)?;
142 let next_id = map.len() as u64; let state = Self {
144 map: RwLock::new(map),
145 next_id: AtomicU64::new(next_id),
146 path,
147 http: reqwest::Client::new(),
148 shard_key: config.cluster_shard_key.clone(),
149 keys: Arc::new(config.api_keys.clone()),
150 autoscale: config.autoscale.clone(),
151 standby: tokio::sync::Mutex::new(config.autoscale.standby_urls.clone()),
152 last_scale: tokio::sync::Mutex::new(None),
153 };
154 Ok(state)
155 }
156
157 fn persist(&self, map: &ShardMap) -> Result<(), Error> {
160 let Some(p) = &self.path else { return Ok(()) };
161 let persisted = Persisted {
162 next_id: self.next_id.load(Ordering::SeqCst),
163 map: map.clone(),
164 };
165 let bytes = serde_json::to_vec_pretty(&persisted)
166 .map_err(|e| Error::Internal(format!("serialize coordinator state: {e}")))?;
167 std::fs::write(p, bytes).map_err(Error::Io)
168 }
169
170 fn auth(&self, rb: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
173 match &self.shard_key {
174 Some(k) => rb.bearer_auth(k),
175 None => rb,
176 }
177 }
178
179 async fn send_json(
181 &self,
182 method: reqwest::Method,
183 url: &str,
184 body: Value,
185 ) -> Result<Value, Error> {
186 let resp = self
187 .auth(self.http.request(method, url).json(&body))
188 .send()
189 .await
190 .map_err(|e| Error::Internal(format!("shard {url} unreachable: {e}")))?;
191 let status = resp.status();
192 let text = resp.text().await.unwrap_or_default();
193 if !status.is_success() {
194 return Err(Error::Internal(format!(
195 "shard {url} returned {status}: {text}"
196 )));
197 }
198 Ok(serde_json::from_str(&text).unwrap_or(Value::Null))
199 }
200
201 async fn list_collection_metas(&self, url: &str) -> Result<Vec<Value>, Error> {
203 let body = self
204 .send_json(
205 reqwest::Method::GET,
206 &format!("{url}/v1/collections"),
207 Value::Null,
208 )
209 .await?;
210 Ok(body.as_array().cloned().unwrap_or_default())
211 }
212
213 async fn ensure_collection(&self, new_url: &str, dto: &Value) -> Result<(), Error> {
216 let name = dto["name"].as_str().unwrap_or_default();
217 let exists = self
218 .auth(self.http.get(format!("{new_url}/v1/collections/{name}")))
219 .send()
220 .await
221 .map(|r| r.status().is_success())
222 .unwrap_or(false);
223 if exists {
224 return Ok(());
225 }
226 let mut body = json!({
227 "name": dto["name"],
228 "dim": dto["dim"],
229 "metric": dto["metric"],
230 "index": dto["index"],
231 });
232 for k in ["pq_subspaces", "filterable", "vector_encryption"] {
233 if let Some(v) = dto.get(k) {
234 body[k] = v.clone();
235 }
236 }
237 self.send_json(
238 reqwest::Method::POST,
239 &format!("{new_url}/v1/collections"),
240 body,
241 )
242 .await
243 .map(|_| ())
244 }
245
246 async fn fetch_page(
248 &self,
249 url: &str,
250 collection: &str,
251 offset: usize,
252 with_vector: bool,
253 ) -> Result<Vec<Value>, Error> {
254 let body = self
255 .send_json(
256 reqwest::Method::POST,
257 &format!("{url}/v1/collections/{collection}/fetch"),
258 json!({"offset": offset, "limit": COPY_PAGE, "with_payload": true, "with_vector": with_vector}),
259 )
260 .await?;
261 Ok(body["points"].as_array().cloned().unwrap_or_default())
262 }
263
264 async fn copy_slice(
269 &self,
270 donor: &str,
271 new_url: &str,
272 collection: &str,
273 map: &ShardMap,
274 new_id: u64,
275 ) -> Result<(), Error> {
276 let mut offset = 0usize;
277 loop {
278 let page = self.fetch_page(donor, collection, offset, true).await?;
279 let n = page.len();
280 for pt in &page {
281 let Some(id) = pt["id"].as_str() else {
282 continue;
283 };
284 if map.shard_for(id).id != new_id {
285 continue;
286 }
287 let get = format!("{new_url}/v1/collections/{collection}/points/{id}");
288 let present = self
289 .auth(self.http.get(&get))
290 .send()
291 .await
292 .map(|r| r.status().is_success())
293 .unwrap_or(false);
294 if present {
295 continue;
296 }
297 self.send_json(
298 reqwest::Method::POST,
299 &format!("{new_url}/v1/collections/{collection}/points"),
300 json!({"points": [{"id": id, "vector": pt["vector"], "payload": pt["payload"]}]}),
301 )
302 .await?;
303 }
304 offset += n;
305 if n < COPY_PAGE {
306 return Ok(());
307 }
308 }
309 }
310
311 async fn drop_slice(
313 &self,
314 donor: &str,
315 collection: &str,
316 map: &ShardMap,
317 new_id: u64,
318 ) -> Result<(), Error> {
319 let mut offset = 0usize;
320 let mut ids: Vec<String> = Vec::new();
321 loop {
322 let page = self.fetch_page(donor, collection, offset, false).await?;
323 let n = page.len();
324 for pt in &page {
325 if let Some(id) = pt["id"].as_str()
326 && map.shard_for(id).id == new_id
327 {
328 ids.push(id.to_owned());
329 }
330 }
331 offset += n;
332 if n < COPY_PAGE {
333 break;
334 }
335 }
336 for chunk in ids.chunks(COPY_PAGE) {
337 self.send_json(
338 reqwest::Method::DELETE,
339 &format!("{donor}/v1/collections/{collection}/points"),
340 json!({ "ids": chunk }),
341 )
342 .await?;
343 }
344 Ok(())
345 }
346
347 async fn run_migration(&self, new_id: u64) -> Result<(), Error> {
352 tokio::time::sleep(MIGRATION_GRACE).await;
353 let map = self.map.read().await.clone();
356 let new_url = map
357 .shards()
358 .iter()
359 .find(|s| s.id == new_id)
360 .map(|s| s.primary_url.clone())
361 .ok_or_else(|| Error::Internal("joining shard left the map".into()))?;
362 let donors: Vec<String> = map
363 .active_shards()
364 .iter()
365 .map(|s| s.primary_url.clone())
366 .collect();
367 let donor0 = donors
368 .first()
369 .ok_or_else(|| Error::Internal("no donor for migration".into()))?;
370 let collections = self.list_collection_metas(donor0).await?;
371 if collections
372 .iter()
373 .any(|c| c["multivector"].as_bool().unwrap_or(false))
374 {
375 return Err(Error::BadRequest(
376 "auto-migration does not yet support multivector collections".into(),
377 ));
378 }
379 for c in &collections {
380 let name = c["name"].as_str().unwrap_or_default().to_owned();
381 self.ensure_collection(&new_url, c).await?;
382 for donor in &donors {
383 self.copy_slice(donor, &new_url, &name, &map, new_id)
384 .await?;
385 }
386 }
387 {
389 let mut m = self.map.write().await;
390 m.promote(new_id)
391 .map_err(|e| Error::BadRequest(e.to_string()))?;
392 self.persist(&m)?;
393 }
394 tokio::time::sleep(MIGRATION_GRACE).await;
395 for c in &collections {
396 let name = c["name"].as_str().unwrap_or_default().to_owned();
397 for donor in &donors {
398 self.drop_slice(donor, &name, &map, new_id).await?;
399 }
400 }
401 tracing::info!(shard = new_id, "cluster migration complete");
402 Ok(())
403 }
404
405 async fn grow_shard(
412 self: &Arc<Self>,
413 primary_url: String,
414 replica_urls: Vec<String>,
415 ) -> Result<ShardMap, Error> {
416 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
417 let snapshot = {
418 let mut map = self.map.write().await;
419 map.add_joining_shard(id, &primary_url, replica_urls)
420 .map_err(|e| Error::BadRequest(e.to_string()))?;
421 self.persist(&map)?;
422 map.clone()
423 };
424 let bg = self.clone();
425 tokio::spawn(async move {
426 if let Err(e) = bg.run_migration(id).await {
427 tracing::error!(shard = id, error = %e, "cluster migration failed; reverting the join");
428 let mut map = bg.map.write().await;
429 let _ = map.remove_shard(id);
430 let _ = bg.persist(&map);
431 }
432 });
433 Ok(snapshot)
434 }
435
436 async fn shard_points(&self, primary_url: &str) -> u64 {
441 let Ok(metas) = self.list_collection_metas(primary_url).await else {
442 return 0;
443 };
444 metas.iter().filter_map(|c| c["count"].as_u64()).sum()
445 }
446
447 async fn maybe_scale_out(self: &Arc<Self>) {
451 let cfg = &self.autoscale;
452 if !cfg.enabled || cfg.high_water_points == 0 {
453 return;
454 }
455 if let Some(t) = *self.last_scale.lock().await
456 && t.elapsed() < Duration::from_secs(cfg.cooldown_secs)
457 {
458 return; }
460 let (active, migrating) = {
461 let map = self.map.read().await;
462 let active: Vec<String> = map
463 .active_shards()
464 .iter()
465 .map(|s| s.primary_url.clone())
466 .collect();
467 let migrating = map.shards().iter().any(|s| map.is_joining(s.id));
468 (active, migrating)
469 };
470 if migrating {
471 return; }
473 if cfg.max_shards != 0 && active.len() >= cfg.max_shards {
474 return;
475 }
476 let mut max_points = 0u64;
477 for url in &active {
478 max_points = max_points.max(self.shard_points(url).await);
479 }
480 if max_points <= cfg.high_water_points {
481 return;
482 }
483 let standby = self.standby.lock().await.pop();
484 let Some(url) = standby else {
485 tracing::warn!(
486 max_points,
487 "autoscale: high-water exceeded but the standby pool is empty"
488 );
489 return;
490 };
491 tracing::info!(max_points, standby = %url, "autoscale: growing the cluster");
492 match self.grow_shard(url.clone(), Vec::new()).await {
493 Ok(_) => *self.last_scale.lock().await = Some(std::time::Instant::now()),
494 Err(e) => {
495 tracing::error!(error = %e, "autoscale grow failed; returning the standby to the pool");
496 self.standby.lock().await.push(url);
497 }
498 }
499 }
500}
501
502fn build_seed_map(config: &Config) -> Result<ShardMap, Error> {
505 let mut map = ShardMap::from_urls(config.cluster_shards.clone())
506 .map_err(|e| Error::Config(e.to_string()))?;
507 for spec in &config.cluster_replicas {
508 let (id, url) = spec.split_once('=').ok_or_else(|| {
509 Error::Config(format!("replica entry {spec:?} must be \"<id>=<url>\""))
510 })?;
511 let id: u64 = id
512 .trim()
513 .parse()
514 .map_err(|_| Error::Config(format!("replica entry {spec:?} has a non-numeric id")))?;
515 map.add_replica(id, url)
516 .map_err(|e| Error::Config(e.to_string()))?;
517 }
518 Ok(map)
519}
520
521pub async fn serve_coordinator(config: Config, listener: TcpListener) -> Result<(), Error> {
525 let state = Arc::new(CoordinatorState::bootstrap(&config)?);
526 let n = state.map.read().await.len();
527 tracing::info!(shards = n, "quiver cluster coordinator started");
528
529 if state.autoscale.enabled {
533 let st = state.clone();
534 let interval = Duration::from_secs(state.autoscale.interval_secs.max(1));
535 tracing::info!(
536 interval_secs = interval.as_secs(),
537 high_water = state.autoscale.high_water_points,
538 standby = state.autoscale.standby_urls.len(),
539 "autoscale policy enabled (scale-out)"
540 );
541 tokio::spawn(async move {
542 loop {
543 tokio::time::sleep(interval).await;
544 st.maybe_scale_out().await;
545 }
546 });
547 }
548 let authed = Router::new()
554 .route("/cluster/map", get(get_map))
555 .route("/cluster/shards", post(add_shard))
556 .route("/cluster/shards/grow", post(grow))
557 .route("/cluster/shards/joining", post(add_joining_shard))
558 .route("/cluster/shards/{id}/promote", post(promote_shard))
559 .route("/cluster/shards/{id}", axum::routing::delete(remove_shard))
560 .route("/cluster/health", get(health))
561 .layer(middleware::from_fn_with_state(
562 state.clone(),
563 coordinator_auth,
564 ))
565 .with_state(state);
566 let app = Router::new()
567 .route("/healthz", get(healthz))
568 .route("/readyz", get(healthz))
569 .merge(authed);
570 axum::serve(listener, app).await.map_err(Error::Io)
571}
572
573async fn coordinator_auth(
578 State(st): State<Arc<CoordinatorState>>,
579 mut request: Request,
580 next: Next,
581) -> Response {
582 let presented = request
583 .headers()
584 .get(AUTHORIZATION)
585 .and_then(|v| v.to_str().ok())
586 .and_then(|v| {
587 v.strip_prefix("Bearer ")
588 .or_else(|| v.strip_prefix("bearer "))
589 })
590 .map(str::to_owned);
591 match auth::authenticate(&st.keys, presented.as_deref()) {
592 Some(principal) => {
593 request.extensions_mut().insert(principal);
594 next.run(request).await
595 }
596 None => {
597 let body = json!({
598 "type": "about:blank",
599 "title": "Unauthorized",
600 "status": 401,
601 "detail": "missing or invalid API key",
602 });
603 (StatusCode::UNAUTHORIZED, Json(body)).into_response()
604 }
605 }
606}
607
/// Probe handler registered for both `/healthz` and `/readyz`; always answers
/// with the plain text `ok`.
async fn healthz() -> &'static str {
    "ok"
}
611
612async fn get_map(State(st): State<Arc<CoordinatorState>>) -> Json<ShardMap> {
614 Json(st.map.read().await.clone())
615}
616
/// JSON body accepted by the endpoints that add a shard (`add_shard`, `grow`
/// and `add_joining_shard`).
#[derive(Deserialize)]
struct AddShardReq {
    /// URL of the new shard's primary node.
    primary_url: String,
    /// URLs of the new shard's replicas; an absent field means none.
    #[serde(default)]
    replica_urls: Vec<String>,
}
623
624async fn add_shard(
627 State(st): State<Arc<CoordinatorState>>,
628 Extension(principal): Extension<Principal>,
629 Json(req): Json<AddShardReq>,
630) -> Result<Json<ShardMap>, Error> {
631 principal.require(Action::Admin, None)?;
632 let id = st.next_id.fetch_add(1, Ordering::SeqCst);
633 let mut map = st.map.write().await;
634 map.add_shard(id, req.primary_url, req.replica_urls)
635 .map_err(|e| Error::BadRequest(e.to_string()))?;
636 st.persist(&map)?;
637 Ok(Json(map.clone()))
638}
639
640async fn grow(
646 State(st): State<Arc<CoordinatorState>>,
647 Extension(principal): Extension<Principal>,
648 Json(req): Json<AddShardReq>,
649) -> Result<Json<ShardMap>, Error> {
650 principal.require(Action::Admin, None)?;
651 let snapshot = st.grow_shard(req.primary_url, req.replica_urls).await?;
652 Ok(Json(snapshot))
653}
654
655async fn add_joining_shard(
659 State(st): State<Arc<CoordinatorState>>,
660 Extension(principal): Extension<Principal>,
661 Json(req): Json<AddShardReq>,
662) -> Result<Json<ShardMap>, Error> {
663 principal.require(Action::Admin, None)?;
664 let id = st.next_id.fetch_add(1, Ordering::SeqCst);
665 let mut map = st.map.write().await;
666 map.add_joining_shard(id, req.primary_url, req.replica_urls)
667 .map_err(|e| Error::BadRequest(e.to_string()))?;
668 st.persist(&map)?;
669 Ok(Json(map.clone()))
670}
671
672async fn promote_shard(
675 State(st): State<Arc<CoordinatorState>>,
676 Extension(principal): Extension<Principal>,
677 Path(id): Path<u64>,
678) -> Result<Json<ShardMap>, Error> {
679 principal.require(Action::Admin, None)?;
680 let mut map = st.map.write().await;
681 map.promote(id)
682 .map_err(|e| Error::BadRequest(e.to_string()))?;
683 st.persist(&map)?;
684 Ok(Json(map.clone()))
685}
686
687async fn remove_shard(
691 State(st): State<Arc<CoordinatorState>>,
692 Extension(principal): Extension<Principal>,
693 Path(id): Path<u64>,
694) -> Result<Json<ShardMap>, Error> {
695 principal.require(Action::Admin, None)?;
696 let mut map = st.map.write().await;
697 map.remove_shard(id)
698 .map_err(|e| Error::BadRequest(e.to_string()))?;
699 st.persist(&map)?;
700 Ok(Json(map.clone()))
701}
702
703async fn health(State(st): State<Arc<CoordinatorState>>) -> impl IntoResponse {
706 let shards = st.map.read().await.shards().to_vec();
707 let mut out = serde_json::Map::new();
708 for shard in shards {
709 let url = format!("{}/healthz", shard.primary_url.trim_end_matches('/'));
710 let up = matches!(st.http.get(&url).send().await, Ok(r) if r.status().is_success());
711 out.insert(shard.id.to_string(), json!(up));
712 }
713 Json(Value::Object(out))
714}
715
/// Unit tests for seed-map construction and the persisted state format.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Config` with the given shard URLs and replica specs; every
    /// other field keeps its default.
    fn config(shards: Vec<&str>, replicas: Vec<&str>) -> Config {
        Config {
            cluster_shards: shards.into_iter().map(String::from).collect(),
            cluster_replicas: replicas.into_iter().map(String::from).collect(),
            ..Default::default()
        }
    }

    /// Two seed shards get ids 0 and 1 at map version 0, and the replica spec
    /// `1=...` lands on the shard with id 1.
    #[test]
    fn build_seed_map_assigns_ids_and_attaches_replicas() {
        let map = build_seed_map(&config(
            vec!["http://s0:6333", "http://s1:6333"],
            vec!["1=http://s1b:6333"],
        ))
        .unwrap();
        assert_eq!(map.version(), 0);
        assert_eq!(
            map.shards().iter().map(|s| s.id).collect::<Vec<_>>(),
            [0, 1]
        );
        assert_eq!(map.shards()[1].replica_urls, ["http://s1b:6333"]);
    }

    /// Every malformed replica spec must surface as `Error::Config`.
    #[test]
    fn build_seed_map_rejects_malformed_replica_specs() {
        // Asserts that building a single-shard map with `replicas` fails with a
        // Config error.
        let err = |replicas| match build_seed_map(&config(vec!["http://s0"], replicas)) {
            Err(Error::Config(_)) => {}
            other => panic!("expected a Config error, got {:?}", other.map(|_| "Ok")),
        };
        // No `=` separator.
        err(vec!["http://no-equals"]);
        // Id is not a number.
        err(vec!["x=http://s"]);
        // Numeric id, but the only seed shard here is `http://s0`.
        err(vec!["9=http://s"]);
    }

    /// A `Persisted` value survives a JSON round trip with its counter, map
    /// version and shard count intact.
    #[test]
    fn persisted_state_round_trips() {
        let mut map = ShardMap::from_urls(["http://s0"]).unwrap();
        map.add_shard(1, "http://s1", vec![]).unwrap();
        let json = serde_json::to_vec(&Persisted { next_id: 2, map }).unwrap();
        let back: Persisted = serde_json::from_slice(&json).unwrap();
        assert_eq!(back.next_id, 2);
        assert_eq!(back.map.version(), 1);
        assert_eq!(back.map.len(), 2);
    }
}
764}