1use std::path::PathBuf;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Duration;
18
19use axum::Json;
20use axum::extract::{Path, State};
21use axum::routing::{get, post};
22use axum::{Router, response::IntoResponse};
23use quiver_cluster::ShardMap;
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use tokio::net::TcpListener;
27use tokio::sync::RwLock;
28
29use crate::Config;
30use crate::error::Error;
31
/// Pause taken before a migration snapshots the map and again after the new
/// shard is promoted — presumably so every router observes the latest map
/// before data moves (TODO confirm against the router's refresh interval).
const MIGRATION_GRACE: Duration = Duration::from_millis(3_000);
/// Page size used when fetching points from a shard during migration, and the
/// batch size for deleting migrated points from donors.
const COPY_PAGE: usize = 1000;
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(default)]
51pub struct AutoscaleConfig {
52 pub enabled: bool,
54 pub high_water_points: u64,
57 pub standby_urls: Vec<String>,
59 pub interval_secs: u64,
61 pub cooldown_secs: u64,
64 pub max_shards: usize,
66}
67
68impl Default for AutoscaleConfig {
69 fn default() -> Self {
70 Self {
71 enabled: false,
72 high_water_points: 0,
73 standby_urls: Vec::new(),
74 interval_secs: 30,
75 cooldown_secs: 300,
76 max_shards: 0,
77 }
78 }
79}
80
81#[derive(Serialize, Deserialize)]
84struct Persisted {
85 next_id: u64,
86 map: ShardMap,
87}
88
89struct CoordinatorState {
91 map: RwLock<ShardMap>,
92 next_id: AtomicU64,
93 path: Option<PathBuf>,
95 http: reqwest::Client,
97 shard_key: Option<String>,
100 autoscale: AutoscaleConfig,
102 standby: tokio::sync::Mutex<Vec<String>>,
104 last_scale: tokio::sync::Mutex<Option<std::time::Instant>>,
106}
107
108impl CoordinatorState {
109 fn bootstrap(config: &Config) -> Result<Self, Error> {
112 let path = config.coordinator_state.clone();
113 if let Some(p) = &path
114 && p.exists()
115 {
116 let bytes = std::fs::read(p).map_err(Error::Io)?;
117 let persisted: Persisted = serde_json::from_slice(&bytes)
118 .map_err(|e| Error::Config(format!("coordinator state {p:?}: {e}")))?;
119 return Ok(Self {
120 map: RwLock::new(persisted.map),
121 next_id: AtomicU64::new(persisted.next_id),
122 path,
123 http: reqwest::Client::new(),
124 shard_key: config.cluster_shard_key.clone(),
125 autoscale: config.autoscale.clone(),
126 standby: tokio::sync::Mutex::new(config.autoscale.standby_urls.clone()),
127 last_scale: tokio::sync::Mutex::new(None),
128 });
129 }
130 let map = build_seed_map(config)?;
131 let next_id = map.len() as u64; let state = Self {
133 map: RwLock::new(map),
134 next_id: AtomicU64::new(next_id),
135 path,
136 http: reqwest::Client::new(),
137 shard_key: config.cluster_shard_key.clone(),
138 autoscale: config.autoscale.clone(),
139 standby: tokio::sync::Mutex::new(config.autoscale.standby_urls.clone()),
140 last_scale: tokio::sync::Mutex::new(None),
141 };
142 Ok(state)
143 }
144
145 fn persist(&self, map: &ShardMap) -> Result<(), Error> {
148 let Some(p) = &self.path else { return Ok(()) };
149 let persisted = Persisted {
150 next_id: self.next_id.load(Ordering::SeqCst),
151 map: map.clone(),
152 };
153 let bytes = serde_json::to_vec_pretty(&persisted)
154 .map_err(|e| Error::Internal(format!("serialize coordinator state: {e}")))?;
155 std::fs::write(p, bytes).map_err(Error::Io)
156 }
157
158 fn auth(&self, rb: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
161 match &self.shard_key {
162 Some(k) => rb.bearer_auth(k),
163 None => rb,
164 }
165 }
166
167 async fn send_json(
169 &self,
170 method: reqwest::Method,
171 url: &str,
172 body: Value,
173 ) -> Result<Value, Error> {
174 let resp = self
175 .auth(self.http.request(method, url).json(&body))
176 .send()
177 .await
178 .map_err(|e| Error::Internal(format!("shard {url} unreachable: {e}")))?;
179 let status = resp.status();
180 let text = resp.text().await.unwrap_or_default();
181 if !status.is_success() {
182 return Err(Error::Internal(format!(
183 "shard {url} returned {status}: {text}"
184 )));
185 }
186 Ok(serde_json::from_str(&text).unwrap_or(Value::Null))
187 }
188
189 async fn list_collection_metas(&self, url: &str) -> Result<Vec<Value>, Error> {
191 let body = self
192 .send_json(
193 reqwest::Method::GET,
194 &format!("{url}/v1/collections"),
195 Value::Null,
196 )
197 .await?;
198 Ok(body.as_array().cloned().unwrap_or_default())
199 }
200
201 async fn ensure_collection(&self, new_url: &str, dto: &Value) -> Result<(), Error> {
204 let name = dto["name"].as_str().unwrap_or_default();
205 let exists = self
206 .auth(self.http.get(format!("{new_url}/v1/collections/{name}")))
207 .send()
208 .await
209 .map(|r| r.status().is_success())
210 .unwrap_or(false);
211 if exists {
212 return Ok(());
213 }
214 let mut body = json!({
215 "name": dto["name"],
216 "dim": dto["dim"],
217 "metric": dto["metric"],
218 "index": dto["index"],
219 });
220 for k in ["pq_subspaces", "filterable", "vector_encryption"] {
221 if let Some(v) = dto.get(k) {
222 body[k] = v.clone();
223 }
224 }
225 self.send_json(
226 reqwest::Method::POST,
227 &format!("{new_url}/v1/collections"),
228 body,
229 )
230 .await
231 .map(|_| ())
232 }
233
234 async fn fetch_page(
236 &self,
237 url: &str,
238 collection: &str,
239 offset: usize,
240 with_vector: bool,
241 ) -> Result<Vec<Value>, Error> {
242 let body = self
243 .send_json(
244 reqwest::Method::POST,
245 &format!("{url}/v1/collections/{collection}/fetch"),
246 json!({"offset": offset, "limit": COPY_PAGE, "with_payload": true, "with_vector": with_vector}),
247 )
248 .await?;
249 Ok(body["points"].as_array().cloned().unwrap_or_default())
250 }
251
252 async fn copy_slice(
257 &self,
258 donor: &str,
259 new_url: &str,
260 collection: &str,
261 map: &ShardMap,
262 new_id: u64,
263 ) -> Result<(), Error> {
264 let mut offset = 0usize;
265 loop {
266 let page = self.fetch_page(donor, collection, offset, true).await?;
267 let n = page.len();
268 for pt in &page {
269 let Some(id) = pt["id"].as_str() else {
270 continue;
271 };
272 if map.shard_for(id).id != new_id {
273 continue;
274 }
275 let get = format!("{new_url}/v1/collections/{collection}/points/{id}");
276 let present = self
277 .auth(self.http.get(&get))
278 .send()
279 .await
280 .map(|r| r.status().is_success())
281 .unwrap_or(false);
282 if present {
283 continue;
284 }
285 self.send_json(
286 reqwest::Method::POST,
287 &format!("{new_url}/v1/collections/{collection}/points"),
288 json!({"points": [{"id": id, "vector": pt["vector"], "payload": pt["payload"]}]}),
289 )
290 .await?;
291 }
292 offset += n;
293 if n < COPY_PAGE {
294 return Ok(());
295 }
296 }
297 }
298
299 async fn drop_slice(
301 &self,
302 donor: &str,
303 collection: &str,
304 map: &ShardMap,
305 new_id: u64,
306 ) -> Result<(), Error> {
307 let mut offset = 0usize;
308 let mut ids: Vec<String> = Vec::new();
309 loop {
310 let page = self.fetch_page(donor, collection, offset, false).await?;
311 let n = page.len();
312 for pt in &page {
313 if let Some(id) = pt["id"].as_str()
314 && map.shard_for(id).id == new_id
315 {
316 ids.push(id.to_owned());
317 }
318 }
319 offset += n;
320 if n < COPY_PAGE {
321 break;
322 }
323 }
324 for chunk in ids.chunks(COPY_PAGE) {
325 self.send_json(
326 reqwest::Method::DELETE,
327 &format!("{donor}/v1/collections/{collection}/points"),
328 json!({ "ids": chunk }),
329 )
330 .await?;
331 }
332 Ok(())
333 }
334
335 async fn run_migration(&self, new_id: u64) -> Result<(), Error> {
340 tokio::time::sleep(MIGRATION_GRACE).await;
341 let map = self.map.read().await.clone();
344 let new_url = map
345 .shards()
346 .iter()
347 .find(|s| s.id == new_id)
348 .map(|s| s.primary_url.clone())
349 .ok_or_else(|| Error::Internal("joining shard left the map".into()))?;
350 let donors: Vec<String> = map
351 .active_shards()
352 .iter()
353 .map(|s| s.primary_url.clone())
354 .collect();
355 let donor0 = donors
356 .first()
357 .ok_or_else(|| Error::Internal("no donor for migration".into()))?;
358 let collections = self.list_collection_metas(donor0).await?;
359 if collections
360 .iter()
361 .any(|c| c["multivector"].as_bool().unwrap_or(false))
362 {
363 return Err(Error::BadRequest(
364 "auto-migration does not yet support multivector collections".into(),
365 ));
366 }
367 for c in &collections {
368 let name = c["name"].as_str().unwrap_or_default().to_owned();
369 self.ensure_collection(&new_url, c).await?;
370 for donor in &donors {
371 self.copy_slice(donor, &new_url, &name, &map, new_id)
372 .await?;
373 }
374 }
375 {
377 let mut m = self.map.write().await;
378 m.promote(new_id)
379 .map_err(|e| Error::BadRequest(e.to_string()))?;
380 self.persist(&m)?;
381 }
382 tokio::time::sleep(MIGRATION_GRACE).await;
383 for c in &collections {
384 let name = c["name"].as_str().unwrap_or_default().to_owned();
385 for donor in &donors {
386 self.drop_slice(donor, &name, &map, new_id).await?;
387 }
388 }
389 tracing::info!(shard = new_id, "cluster migration complete");
390 Ok(())
391 }
392
393 async fn grow_shard(
400 self: &Arc<Self>,
401 primary_url: String,
402 replica_urls: Vec<String>,
403 ) -> Result<ShardMap, Error> {
404 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
405 let snapshot = {
406 let mut map = self.map.write().await;
407 map.add_joining_shard(id, &primary_url, replica_urls)
408 .map_err(|e| Error::BadRequest(e.to_string()))?;
409 self.persist(&map)?;
410 map.clone()
411 };
412 let bg = self.clone();
413 tokio::spawn(async move {
414 if let Err(e) = bg.run_migration(id).await {
415 tracing::error!(shard = id, error = %e, "cluster migration failed; reverting the join");
416 let mut map = bg.map.write().await;
417 let _ = map.remove_shard(id);
418 let _ = bg.persist(&map);
419 }
420 });
421 Ok(snapshot)
422 }
423
424 async fn shard_points(&self, primary_url: &str) -> u64 {
429 let Ok(metas) = self.list_collection_metas(primary_url).await else {
430 return 0;
431 };
432 metas.iter().filter_map(|c| c["count"].as_u64()).sum()
433 }
434
435 async fn maybe_scale_out(self: &Arc<Self>) {
439 let cfg = &self.autoscale;
440 if !cfg.enabled || cfg.high_water_points == 0 {
441 return;
442 }
443 if let Some(t) = *self.last_scale.lock().await
444 && t.elapsed() < Duration::from_secs(cfg.cooldown_secs)
445 {
446 return; }
448 let (active, migrating) = {
449 let map = self.map.read().await;
450 let active: Vec<String> = map
451 .active_shards()
452 .iter()
453 .map(|s| s.primary_url.clone())
454 .collect();
455 let migrating = map.shards().iter().any(|s| map.is_joining(s.id));
456 (active, migrating)
457 };
458 if migrating {
459 return; }
461 if cfg.max_shards != 0 && active.len() >= cfg.max_shards {
462 return;
463 }
464 let mut max_points = 0u64;
465 for url in &active {
466 max_points = max_points.max(self.shard_points(url).await);
467 }
468 if max_points <= cfg.high_water_points {
469 return;
470 }
471 let standby = self.standby.lock().await.pop();
472 let Some(url) = standby else {
473 tracing::warn!(
474 max_points,
475 "autoscale: high-water exceeded but the standby pool is empty"
476 );
477 return;
478 };
479 tracing::info!(max_points, standby = %url, "autoscale: growing the cluster");
480 match self.grow_shard(url.clone(), Vec::new()).await {
481 Ok(_) => *self.last_scale.lock().await = Some(std::time::Instant::now()),
482 Err(e) => {
483 tracing::error!(error = %e, "autoscale grow failed; returning the standby to the pool");
484 self.standby.lock().await.push(url);
485 }
486 }
487 }
488}
489
490fn build_seed_map(config: &Config) -> Result<ShardMap, Error> {
493 let mut map = ShardMap::from_urls(config.cluster_shards.clone())
494 .map_err(|e| Error::Config(e.to_string()))?;
495 for spec in &config.cluster_replicas {
496 let (id, url) = spec.split_once('=').ok_or_else(|| {
497 Error::Config(format!("replica entry {spec:?} must be \"<id>=<url>\""))
498 })?;
499 let id: u64 = id
500 .trim()
501 .parse()
502 .map_err(|_| Error::Config(format!("replica entry {spec:?} has a non-numeric id")))?;
503 map.add_replica(id, url)
504 .map_err(|e| Error::Config(e.to_string()))?;
505 }
506 Ok(map)
507}
508
509pub async fn serve_coordinator(config: Config, listener: TcpListener) -> Result<(), Error> {
513 let state = Arc::new(CoordinatorState::bootstrap(&config)?);
514 let n = state.map.read().await.len();
515 tracing::info!(shards = n, "quiver cluster coordinator started");
516
517 if state.autoscale.enabled {
521 let st = state.clone();
522 let interval = Duration::from_secs(state.autoscale.interval_secs.max(1));
523 tracing::info!(
524 interval_secs = interval.as_secs(),
525 high_water = state.autoscale.high_water_points,
526 standby = state.autoscale.standby_urls.len(),
527 "autoscale policy enabled (scale-out)"
528 );
529 tokio::spawn(async move {
530 loop {
531 tokio::time::sleep(interval).await;
532 st.maybe_scale_out().await;
533 }
534 });
535 }
536 let app = Router::new()
537 .route("/healthz", get(healthz))
538 .route("/readyz", get(healthz))
539 .route("/cluster/map", get(get_map))
540 .route("/cluster/shards", post(add_shard))
541 .route("/cluster/shards/grow", post(grow))
542 .route("/cluster/shards/joining", post(add_joining_shard))
543 .route("/cluster/shards/{id}/promote", post(promote_shard))
544 .route("/cluster/shards/{id}", axum::routing::delete(remove_shard))
545 .route("/cluster/health", get(health))
546 .with_state(state);
547 axum::serve(listener, app).await.map_err(Error::Io)
548}
549
/// Liveness probe served on `/healthz` and `/readyz`; always answers `ok`.
async fn healthz() -> &'static str {
    const OK: &str = "ok";
    OK
}
553
554async fn get_map(State(st): State<Arc<CoordinatorState>>) -> Json<ShardMap> {
556 Json(st.map.read().await.clone())
557}
558
559#[derive(Deserialize)]
560struct AddShardReq {
561 primary_url: String,
562 #[serde(default)]
563 replica_urls: Vec<String>,
564}
565
566async fn add_shard(
569 State(st): State<Arc<CoordinatorState>>,
570 Json(req): Json<AddShardReq>,
571) -> Result<Json<ShardMap>, Error> {
572 let id = st.next_id.fetch_add(1, Ordering::SeqCst);
573 let mut map = st.map.write().await;
574 map.add_shard(id, req.primary_url, req.replica_urls)
575 .map_err(|e| Error::BadRequest(e.to_string()))?;
576 st.persist(&map)?;
577 Ok(Json(map.clone()))
578}
579
580async fn grow(
586 State(st): State<Arc<CoordinatorState>>,
587 Json(req): Json<AddShardReq>,
588) -> Result<Json<ShardMap>, Error> {
589 let snapshot = st.grow_shard(req.primary_url, req.replica_urls).await?;
590 Ok(Json(snapshot))
591}
592
593async fn add_joining_shard(
597 State(st): State<Arc<CoordinatorState>>,
598 Json(req): Json<AddShardReq>,
599) -> Result<Json<ShardMap>, Error> {
600 let id = st.next_id.fetch_add(1, Ordering::SeqCst);
601 let mut map = st.map.write().await;
602 map.add_joining_shard(id, req.primary_url, req.replica_urls)
603 .map_err(|e| Error::BadRequest(e.to_string()))?;
604 st.persist(&map)?;
605 Ok(Json(map.clone()))
606}
607
608async fn promote_shard(
611 State(st): State<Arc<CoordinatorState>>,
612 Path(id): Path<u64>,
613) -> Result<Json<ShardMap>, Error> {
614 let mut map = st.map.write().await;
615 map.promote(id)
616 .map_err(|e| Error::BadRequest(e.to_string()))?;
617 st.persist(&map)?;
618 Ok(Json(map.clone()))
619}
620
621async fn remove_shard(
625 State(st): State<Arc<CoordinatorState>>,
626 Path(id): Path<u64>,
627) -> Result<Json<ShardMap>, Error> {
628 let mut map = st.map.write().await;
629 map.remove_shard(id)
630 .map_err(|e| Error::BadRequest(e.to_string()))?;
631 st.persist(&map)?;
632 Ok(Json(map.clone()))
633}
634
635async fn health(State(st): State<Arc<CoordinatorState>>) -> impl IntoResponse {
638 let shards = st.map.read().await.shards().to_vec();
639 let mut out = serde_json::Map::new();
640 for shard in shards {
641 let url = format!("{}/healthz", shard.primary_url.trim_end_matches('/'));
642 let up = matches!(st.http.get(&url).send().await, Ok(r) if r.status().is_success());
643 out.insert(shard.id.to_string(), json!(up));
644 }
645 Json(Value::Object(out))
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Config` with only the static shard and replica lists filled in.
    fn config(shards: Vec<&str>, replicas: Vec<&str>) -> Config {
        Config {
            cluster_shards: shards.iter().map(|s| s.to_string()).collect(),
            cluster_replicas: replicas.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    #[test]
    fn build_seed_map_assigns_ids_and_attaches_replicas() {
        let primaries = vec!["http://s0:6333", "http://s1:6333"];
        let replicas = vec!["1=http://s1b:6333"];
        let map = build_seed_map(&config(primaries, replicas)).unwrap();

        assert_eq!(map.version(), 0);
        let ids: Vec<_> = map.shards().iter().map(|s| s.id).collect();
        assert_eq!(ids, [0, 1]);
        assert_eq!(map.shards()[1].replica_urls, ["http://s1b:6333"]);
    }

    #[test]
    fn build_seed_map_rejects_malformed_replica_specs() {
        let expect_config_err = |replicas| {
            match build_seed_map(&config(vec!["http://s0"], replicas)) {
                Err(Error::Config(_)) => {}
                other => panic!("expected a Config error, got {:?}", other.map(|_| "Ok")),
            }
        };
        // No `=` separator.
        expect_config_err(vec!["http://no-equals"]);
        // Non-numeric shard id.
        expect_config_err(vec!["x=http://s"]);
        // Shard id that is not in the map.
        expect_config_err(vec!["9=http://s"]);
    }

    #[test]
    fn persisted_state_round_trips() {
        let mut map = ShardMap::from_urls(["http://s0"]).unwrap();
        map.add_shard(1, "http://s1", vec![]).unwrap();

        let encoded = serde_json::to_vec(&Persisted { next_id: 2, map }).unwrap();
        let decoded: Persisted = serde_json::from_slice(&encoded).unwrap();

        assert_eq!(decoded.next_id, 2);
        assert_eq!(decoded.map.version(), 1);
        assert_eq!(decoded.map.len(), 2);
    }
}