use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use reqwest::{Client, StatusCode};
use crate::discovery::Discovery;
use crate::error::{Result, StoreError};
use crate::proto::AgentRecord;
/// [`Discovery`] implementation that talks to a remote agent registry over HTTP.
///
/// Holds a pre-configured `reqwest` client (which may carry a bearer token as a
/// default `Authorization` header) and the registry base URL with trailing
/// slashes removed.
pub struct RegistryFabric {
    /// HTTP client; its default headers may contain the bearer credential, so
    /// it is deliberately left out of the `Debug` output below.
    client: Client,
    /// Registry base URL without trailing slashes.
    base: String,
}

impl std::fmt::Debug for RegistryFabric {
    /// Prints only the base URL. The client is omitted because its default
    /// headers may hold the bearer token.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegistryFabric")
            .field("base", &self.base)
            .finish_non_exhaustive()
    }
}
impl RegistryFabric {
    /// Creates a client for the registry rooted at `base_url` that sends no
    /// `Authorization` header.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::Network`] if the underlying HTTP client cannot be built.
    pub fn new(base_url: impl Into<String>) -> Result<Self> {
        Self::build(base_url.into(), None)
    }

    /// Creates a client that sends `Authorization: Bearer <token>` on every request.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::Network`] if `token` cannot be encoded as a header
    /// value (the error text does not echo the token), or if the HTTP client
    /// cannot be built.
    pub fn with_token(base_url: impl Into<String>, token: impl Into<String>) -> Result<Self> {
        Self::build(base_url.into(), Some(token.into()))
    }

    /// Shared constructor behind [`RegistryFabric::new`] and
    /// [`RegistryFabric::with_token`].
    fn build(base_url: String, token: Option<String>) -> Result<Self> {
        let mut headers = HeaderMap::new();
        if let Some(token) = token {
            // The error message is deliberately generic: the token must never
            // be interpolated into an error.
            let mut value = HeaderValue::from_str(&format!("Bearer {token}"))
                .map_err(|_| StoreError::Network("invalid bearer token".into()))?;
            // Flag the credential as sensitive so `http` redacts it when the
            // header value is Debug-formatted.
            value.set_sensitive(true);
            headers.insert(AUTHORIZATION, value);
        }
        // Installed as default headers so every request from this client
        // carries the credential.
        let client = Client::builder()
            .default_headers(headers)
            .build()
            .map_err(net_err)?;
        Ok(Self {
            client,
            // Strip trailing slashes so the URL helpers can append `/v1/...`
            // without producing `//`.
            base: base_url.trim_end_matches('/').to_string(),
        })
    }

    /// Collection endpoint: `<base>/v1/agents`.
    fn agents_url(&self) -> String {
        format!("{}/v1/agents", self.base)
    }

    /// Item endpoint: `<base>/v1/agents/<agent_id>`.
    ///
    /// `agent_id` is percent-encoded so that ids containing `/`, `?`, `#`, `%`
    /// or spaces address exactly one path segment, instead of silently changing
    /// the route or starting a query string or fragment.
    fn agent_url(&self, agent_id: &str) -> String {
        format!(
            "{}/v1/agents/{}",
            self.base,
            Self::encode_path_segment(agent_id)
        )
    }

    /// Percent-encodes every UTF-8 byte of `segment` that lies outside the
    /// RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`).
    ///
    /// NOTE(review): ids that are exactly `.` or `..` pass through unchanged
    /// and will presumably still be treated as dot-segments when the URL is
    /// parsed; callers should avoid such ids.
    fn encode_path_segment(segment: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(segment.len());
        for byte in segment.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }
}
#[async_trait]
impl Discovery for RegistryFabric {
async fn publish(&self, rec: &AgentRecord) -> Result<()> {
let resp = self
.client
.post(self.agents_url())
.json(rec)
.send()
.await
.map_err(net_err)?;
ensure_ok(resp).await.map(|_| ())
}
async fn resolve(&self, agent_id: &str) -> Result<Option<AgentRecord>> {
let resp = self
.client
.get(self.agent_url(agent_id))
.send()
.await
.map_err(net_err)?;
if resp.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
let resp = ensure_ok(resp).await?;
let rec = resp.json::<AgentRecord>().await.map_err(net_err)?;
Ok(Some(rec))
}
async fn discover(&self) -> Result<Vec<AgentRecord>> {
let resp = self
.client
.get(self.agents_url())
.send()
.await
.map_err(net_err)?;
let resp = ensure_ok(resp).await?;
let recs = resp.json::<Vec<AgentRecord>>().await.map_err(net_err)?;
Ok(recs)
}
async fn withdraw(&self, agent_id: &str) -> Result<()> {
let resp = self
.client
.delete(self.agent_url(agent_id))
.send()
.await
.map_err(net_err)?;
if resp.status() == StatusCode::NOT_FOUND {
return Ok(());
}
ensure_ok(resp).await.map(|_| ())
}
async fn gc(&self) -> Result<usize> {
Ok(0)
}
}
fn net_err(e: reqwest::Error) -> StoreError {
StoreError::Network(e.to_string())
}
/// Passes a 2xx response through unchanged. Any other status becomes a
/// [`StoreError::Network`] carrying the status.
///
/// 401 and 403 get a message that points at the credential. The message
/// never includes the credential itself, and the response body is not read.
async fn ensure_ok(resp: reqwest::Response) -> Result<reqwest::Response> {
    let code = resp.status();
    if code.is_success() {
        return Ok(resp);
    }
    let auth_rejected = code == StatusCode::UNAUTHORIZED || code == StatusCode::FORBIDDEN;
    let message = if auth_rejected {
        format!("registry rejected the request: {code} (missing or invalid credential)")
    } else {
        format!("registry returned an error status: {code}")
    };
    Err(StoreError::Network(message))
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration, Utc};
    use wiremock::matchers::{header, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a minimal record with the given id and a 30-second lease.
    fn rec(id: &str) -> AgentRecord {
        AgentRecord {
            agent_id: id.into(),
            role: "service".into(),
            labels: vec!["gpu".into()],
            endpoint: "ws://10.0.0.2:8443".into(),
            pid: 7,
            version: "1".into(),
            started_at: Utc::now(),
            lease_expires_at: Utc::now() + Duration::seconds(30),
        }
    }

    /// Exercises every `Discovery` verb against a mock registry.
    ///
    /// Every route requires the bearer header and carries `expect(1)`.
    /// wiremock answers unmatched requests with 404, which `withdraw` treats
    /// as success, so without the expectation a DELETE that lacked the
    /// credential would pass silently. Expectations are verified when
    /// `server` is dropped.
    #[tokio::test]
    async fn publish_resolve_discover_withdraw_round_trip_over_http() {
        let server = MockServer::start().await;
        let r = rec("a");
        Mock::given(method("POST"))
            .and(path("/v1/agents"))
            .and(header("authorization", "Bearer t-secret"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&r))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/v1/agents/a"))
            .and(header("authorization", "Bearer t-secret"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&r))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/v1/agents"))
            .and(header("authorization", "Bearer t-secret"))
            .respond_with(ResponseTemplate::new(200).set_body_json(vec![r.clone()]))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("DELETE"))
            .and(path("/v1/agents/a"))
            .and(header("authorization", "Bearer t-secret"))
            .respond_with(ResponseTemplate::new(200))
            .expect(1)
            .mount(&server)
            .await;
        let fab = RegistryFabric::with_token(server.uri(), "t-secret").unwrap();
        let disc: &dyn Discovery = &fab;
        disc.publish(&r).await.unwrap();
        let got = disc.resolve("a").await.unwrap().unwrap();
        assert_eq!(got.agent_id, "a");
        assert_eq!(disc.discover().await.unwrap().len(), 1);
        disc.withdraw("a").await.unwrap();
        assert_eq!(disc.gc().await.unwrap(), 0);
    }

    /// A registry 404 means "absent" for `resolve` and "already gone" for
    /// `withdraw`. `expect(1)` proves the 404 came from these mocks rather
    /// than from wiremock's unmatched-request fallback.
    #[tokio::test]
    async fn resolve_missing_is_none_and_withdraw_missing_is_ok() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/agents/ghost"))
            .respond_with(ResponseTemplate::new(404))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("DELETE"))
            .and(path("/v1/agents/ghost"))
            .respond_with(ResponseTemplate::new(404))
            .expect(1)
            .mount(&server)
            .await;
        let fab = RegistryFabric::new(server.uri()).unwrap();
        assert!(fab.resolve("ghost").await.unwrap().is_none());
        fab.withdraw("ghost").await.unwrap();
    }

    /// An empty collection decodes to an empty list.
    #[tokio::test]
    async fn discover_on_empty_registry_returns_empty_list() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/agents"))
            .respond_with(ResponseTemplate::new(200).set_body_json(Vec::<AgentRecord>::new()))
            .expect(1)
            .mount(&server)
            .await;
        let fab = RegistryFabric::new(server.uri()).unwrap();
        assert!(fab.discover().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn unauthorized_maps_to_network_error_without_leaking_token() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/agents"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&server)
            .await;
        let fab = RegistryFabric::with_token(server.uri(), "super-secret-token").unwrap();
        let err = fab.publish(&rec("x")).await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("401"), "expected 401 in: {msg}");
        assert!(
            !msg.contains("super-secret-token"),
            "token must never leak into the error: {msg}"
        );
    }

    /// 403 on an item and 5xx on the collection both surface as errors that
    /// name the status. `.err().expect(..)` avoids requiring `Debug` on the
    /// success types.
    #[tokio::test]
    async fn non_success_statuses_surface_as_errors() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/agents/locked"))
            .respond_with(ResponseTemplate::new(403))
            .expect(1)
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/v1/agents"))
            .respond_with(ResponseTemplate::new(500))
            .expect(1)
            .mount(&server)
            .await;
        let fab = RegistryFabric::new(server.uri()).unwrap();
        let forbidden = fab
            .resolve("locked")
            .await
            .err()
            .expect("403 must be reported as an error")
            .to_string();
        assert!(forbidden.contains("403"), "expected 403 in: {forbidden}");
        let failed = fab
            .discover()
            .await
            .err()
            .expect("500 must be reported as an error")
            .to_string();
        assert!(failed.contains("500"), "expected 500 in: {failed}");
    }

    #[tokio::test]
    async fn trailing_slash_in_base_is_normalized() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/agents"))
            .respond_with(ResponseTemplate::new(200).set_body_json(Vec::<AgentRecord>::new()))
            .expect(1)
            .mount(&server)
            .await;
        let fab = RegistryFabric::new(format!("{}/", server.uri())).unwrap();
        assert!(fab.discover().await.unwrap().is_empty());
    }

    /// Sanity check of the mock's query matching only. It uses a bare
    /// `reqwest::Client` and does NOT exercise `RegistryFabric`, whose
    /// `discover` takes no filter arguments.
    #[tokio::test]
    async fn query_param_route_is_exercisable() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/agents"))
            .and(query_param("role", "service"))
            .respond_with(ResponseTemplate::new(200).set_body_json(Vec::<AgentRecord>::new()))
            .mount(&server)
            .await;
        let client = reqwest::Client::new();
        let resp = client
            .get(format!("{}/v1/agents?role=service", server.uri()))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }
}