use std::pin::Pin;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use super::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
};
use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent};
/// Name of the bucket that holds endpoint instances.
///
/// Endpoint registrations are inserted here, and every endpoint-related
/// `DiscoveryQuery` prefix starts with this value.
const INSTANCES_BUCKET: &str = "v1/instances";
/// Name of the bucket that holds model instances.
///
/// Model registrations are inserted here, and every model-related
/// `DiscoveryQuery` prefix starts with this value.
// NOTE(review): "mdc" presumably abbreviates the model card type stored here — confirm.
const MODELS_BUCKET: &str = "v1/mdc";
/// [`Discovery`] implementation backed by a [`KeyValueStoreManager`].
///
/// Instances are serialized to JSON and stored under `/`-separated keys in one
/// of two buckets (endpoints and models); listing and watching read them back.
pub struct KVStoreDiscovery {
/// Shared handle to the key-value store used for registering, listing and watching.
store: Arc<KeyValueStoreManager>,
/// Default cancellation token, used by `list_and_watch` when the caller passes `None`.
cancel_token: CancellationToken,
}
impl KVStoreDiscovery {
/// Creates a discovery client on top of `store`.
///
/// The store is wrapped in an [`Arc`] so it can be shared with watch tasks.
/// `cancel_token` becomes the default token for watches that are started
/// without an explicit one.
pub fn new(store: KeyValueStoreManager, cancel_token: CancellationToken) -> Self {
    let store = Arc::new(store);
    Self {
        store,
        cancel_token,
    }
}
/// Builds the bucket-relative key for an endpoint instance.
///
/// The layout is `namespace/component/endpoint/<instance_id>`, where the
/// instance id is rendered as lowercase hexadecimal without padding.
fn endpoint_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
    format!("{namespace}/{component}/{endpoint}/{instance_id:x}")
}
/// Builds the bucket-relative key for a model instance.
///
/// The layout is `namespace/component/endpoint/<instance_id>`, where the
/// instance id is rendered as lowercase hexadecimal without padding.
fn model_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
    format!("{namespace}/{component}/{endpoint}/{instance_id:x}")
}
/// Translates a [`DiscoveryQuery`] into the key prefix that selects its entries.
///
/// The result always starts with the bucket name (`INSTANCES_BUCKET` for
/// endpoint queries, `MODELS_BUCKET` for model queries). Whatever namespace,
/// component and endpoint the query carries are appended, joined by `/`.
/// No trailing `/` is added.
fn query_prefix(query: &DiscoveryQuery) -> String {
    // First decide which bucket is addressed and which path (if any) lies beneath it.
    let (bucket, scope) = match query {
        DiscoveryQuery::AllEndpoints => (INSTANCES_BUCKET, None),
        DiscoveryQuery::NamespacedEndpoints { namespace } => {
            (INSTANCES_BUCKET, Some(namespace.to_string()))
        }
        DiscoveryQuery::ComponentEndpoints {
            namespace,
            component,
        } => (INSTANCES_BUCKET, Some(format!("{namespace}/{component}"))),
        DiscoveryQuery::Endpoint {
            namespace,
            component,
            endpoint,
        } => (
            INSTANCES_BUCKET,
            Some(format!("{namespace}/{component}/{endpoint}")),
        ),
        DiscoveryQuery::AllModels => (MODELS_BUCKET, None),
        DiscoveryQuery::NamespacedModels { namespace } => {
            (MODELS_BUCKET, Some(namespace.to_string()))
        }
        DiscoveryQuery::ComponentModels {
            namespace,
            component,
        } => (MODELS_BUCKET, Some(format!("{namespace}/{component}"))),
        DiscoveryQuery::EndpointModels {
            namespace,
            component,
            endpoint,
        } => (
            MODELS_BUCKET,
            Some(format!("{namespace}/{component}/{endpoint}")),
        ),
    };

    // Then glue the two together; a bucket-wide query is just the bucket name.
    match scope {
        Some(path) => format!("{bucket}/{path}"),
        None => bucket.to_string(),
    }
}
/// Returns `key` relative to `bucket_name`.
///
/// When `key` starts with the bucket name followed by `/` (or equals the
/// bucket name exactly), the bucket name and the separating `/` are removed.
/// Keys that are already relative are returned unchanged.
///
/// The bucket name is only removed at a segment boundary: a key such as
/// `v1/instances2/x` is *not* treated as belonging to bucket `v1/instances`,
/// because the match would end in the middle of a path segment.
fn strip_bucket_prefix<'a>(key: &'a str, bucket_name: &str) -> &'a str {
    let Some(rest) = key.strip_prefix(bucket_name) else {
        return key;
    };
    match rest.strip_prefix('/') {
        Some(relative) => relative,
        // No separator follows the bucket name. That is still a clean match when
        // nothing is left, or when the bucket name itself is empty or already
        // ends with the separator.
        None if rest.is_empty() || bucket_name.is_empty() || bucket_name.ends_with('/') => rest,
        // Otherwise the bucket name only matched part of the first segment.
        None => key,
    }
}
/// Reports whether `key_str` falls under `prefix` within `bucket_name`.
///
/// Both arguments may be given with or without the leading bucket name; they
/// are normalised to bucket-relative form before comparison. An empty relative
/// prefix matches every key in the bucket.
///
/// Matching is done on whole path segments: the key must either equal the
/// prefix or continue with a `/` right after it. A query for namespace `ns1`
/// therefore does not pick up keys under `ns10/...`, and endpoint `ep` does
/// not pick up `ep2/...`.
fn matches_prefix(key_str: &str, prefix: &str, bucket_name: &str) -> bool {
    let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
    let relative_prefix = Self::strip_bucket_prefix(prefix, bucket_name);
    if relative_prefix.is_empty() {
        return true;
    }
    match relative_key.strip_prefix(relative_prefix) {
        // A textual match only counts when it ends on a segment boundary.
        Some(rest) => {
            rest.is_empty() || rest.starts_with('/') || relative_prefix.ends_with('/')
        }
        None => false,
    }
}
/// Deserializes a JSON-encoded [`DiscoveryInstance`] from raw bytes.
///
/// # Errors
///
/// Returns the `serde_json` error (converted into the function's error type)
/// when `value` is not a valid JSON representation of a `DiscoveryInstance`.
fn parse_instance(value: &[u8]) -> Result<DiscoveryInstance> {
    serde_json::from_slice::<DiscoveryInstance>(value).map_err(Into::into)
}
}
#[async_trait]
impl Discovery for KVStoreDiscovery {
/// Returns the id that `register` stamps onto instances created through this client.
///
/// The value is whatever the underlying store reports as its `connection_id()`.
fn instance_id(&self) -> u64 {
self.store.connection_id()
}
/// Registers `spec` in the key-value store and returns the resulting instance.
///
/// The spec is combined with this client's instance id, serialized to JSON and
/// inserted into the endpoint bucket or the model bucket, depending on the
/// instance variant. The key is `namespace/component/endpoint/<instance_id hex>`.
///
/// # Errors
///
/// Fails when the instance cannot be serialized, when the bucket cannot be
/// obtained or created, or when the insert itself fails.
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
    let instance_id = self.instance_id();
    let instance = spec.with_instance_id(instance_id);

    // Both values are assigned exactly once, in whichever arm matches.
    let bucket_name: &str;
    let key_path: String;
    match &instance {
        DiscoveryInstance::Endpoint(ep) => {
            key_path =
                Self::endpoint_key(&ep.namespace, &ep.component, &ep.endpoint, ep.instance_id);
            tracing::debug!(
                "KVStoreDiscovery::register: Registering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
                ep.instance_id,
                ep.namespace,
                ep.component,
                ep.endpoint,
                key_path
            );
            bucket_name = INSTANCES_BUCKET;
        }
        DiscoveryInstance::Model {
            namespace: ns,
            component: comp,
            endpoint: ep_name,
            instance_id: model_instance_id,
            ..
        } => {
            key_path = Self::model_key(ns, comp, ep_name, *model_instance_id);
            tracing::debug!(
                "KVStoreDiscovery::register: Registering model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
                model_instance_id,
                ns,
                comp,
                ep_name,
                key_path
            );
            bucket_name = MODELS_BUCKET;
        }
    }

    let payload = serde_json::to_vec(&instance)?;
    tracing::debug!(
        "KVStoreDiscovery::register: Serialized instance to {} bytes for key={}",
        payload.len(),
        key_path
    );

    tracing::debug!(
        "KVStoreDiscovery::register: Getting/creating bucket={} for key={}",
        bucket_name,
        key_path
    );
    let bucket = self.store.get_or_create_bucket(bucket_name, None).await?;

    // The path is cloned because it is still needed for the log lines below.
    let store_key = crate::storage::key_value_store::Key::from_raw(key_path.clone());
    tracing::debug!(
        "KVStoreDiscovery::register: Inserting into bucket={}, key={}",
        bucket_name,
        key_path
    );
    let outcome = bucket.insert(&store_key, payload.into(), 0).await?;
    tracing::debug!(
        "KVStoreDiscovery::register: Successfully registered instance_id={}, key={}, outcome={:?}",
        instance_id,
        key_path,
        outcome
    );

    Ok(instance)
}
async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
let prefix = Self::query_prefix(&query);
let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
INSTANCES_BUCKET
} else {
MODELS_BUCKET
};
let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
return Ok(Vec::new());
};
let entries = bucket.entries().await?;
let mut instances = Vec::new();
for (key_str, value) in entries {
if Self::matches_prefix(&key_str, &prefix, bucket_name) {
match Self::parse_instance(&value) {
Ok(instance) => instances.push(instance),
Err(e) => {
tracing::warn!(key = %key_str, error = %e, "Failed to parse discovery instance");
}
}
}
}
Ok(instances)
}
/// Opens a stream of [`DiscoveryEvent`]s for everything matching `query`.
///
/// A watch is started on the bucket selected by the query prefix. Each event
/// whose key falls under the prefix is translated as follows:
///
/// * `Put` — the value is deserialized and emitted as `DiscoveryEvent::Added`.
/// * `Delete` — the last `/`-separated key segment is parsed as a hexadecimal
///   instance id and emitted as `DiscoveryEvent::Removed`.
///
/// Events that do not match the prefix, values that fail to deserialize and
/// delete keys without a valid hex id are logged and skipped; they never
/// surface as stream errors. The stream ends when the watch receiver closes.
///
/// `cancel_token` controls the watch; when `None`, the token given to
/// [`KVStoreDiscovery::new`] is used.
///
/// This implementation always returns `Ok`.
// NOTE(review): whether entries that already exist are replayed as `Put`
// events depends on `KeyValueStoreManager::watch` — confirm against that API.
async fn list_and_watch(
    &self,
    query: DiscoveryQuery,
    cancel_token: Option<CancellationToken>,
) -> Result<DiscoveryStream> {
    let prefix = Self::query_prefix(&query);
    let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
        INSTANCES_BUCKET
    } else {
        MODELS_BUCKET
    };
    tracing::debug!(
        "KVStoreDiscovery::list_and_watch: Starting watch for query={:?}, prefix={}, bucket={}",
        query,
        prefix,
        bucket_name
    );

    // Fall back to the client-wide token when the caller did not supply one.
    let cancel_token = cancel_token.unwrap_or_else(|| self.cancel_token.clone());
    let (_, mut rx) = self.store.clone().watch(bucket_name, None, cancel_token);
    tracing::debug!(
        "KVStoreDiscovery::list_and_watch: Got watch receiver for bucket={}",
        bucket_name
    );

    let stream = async_stream::stream! {
        // Explicitly u64: an unannotated counter defaults to i32, which a
        // long-lived watch could overflow.
        let mut event_count: u64 = 0;
        tracing::debug!("KVStoreDiscovery::list_and_watch: Stream started, waiting for events on prefix={}", prefix);
        while let Some(event) = rx.recv().await {
            event_count += 1;
            tracing::debug!(
                "KVStoreDiscovery::list_and_watch: Received event #{} for prefix={}",
                event_count,
                prefix
            );
            let discovery_event = match event {
                WatchEvent::Put(kv) => {
                    // Evaluate the match once and reuse it for the log and the filter.
                    let key_matches = Self::matches_prefix(kv.key_str(), &prefix, bucket_name);
                    tracing::debug!(
                        "KVStoreDiscovery::list_and_watch: Put event, key={}, prefix={}, matches={}",
                        kv.key_str(),
                        prefix,
                        key_matches
                    );
                    if !key_matches {
                        tracing::debug!(
                            "KVStoreDiscovery::list_and_watch: Skipping key {} (doesn't match prefix {})",
                            kv.key_str(),
                            prefix
                        );
                        continue;
                    }
                    match Self::parse_instance(kv.value()) {
                        Ok(instance) => {
                            tracing::debug!(
                                "KVStoreDiscovery::list_and_watch: Emitting Added event for instance_id={}, key={}",
                                instance.instance_id(),
                                kv.key_str()
                            );
                            Some(DiscoveryEvent::Added(instance))
                        },
                        Err(e) => {
                            tracing::warn!(
                                key = %kv.key_str(),
                                error = %e,
                                "Failed to parse discovery instance from watch event"
                            );
                            None
                        }
                    }
                }
                WatchEvent::Delete(kv) => {
                    let key_str: &str = kv.as_ref();
                    tracing::debug!(
                        "KVStoreDiscovery::list_and_watch: Delete event, key={}, prefix={}",
                        key_str,
                        prefix
                    );
                    if !Self::matches_prefix(key_str, &prefix, bucket_name) {
                        tracing::debug!(
                            "KVStoreDiscovery::list_and_watch: Skipping deleted key {} (doesn't match prefix {})",
                            key_str,
                            prefix
                        );
                        continue;
                    }
                    // The value is gone on delete, so the instance id is recovered
                    // from the final key segment. `rsplit` always yields at least
                    // one item; the fallback only satisfies the type.
                    let instance_id_hex = key_str.rsplit('/').next().unwrap_or(key_str);
                    match u64::from_str_radix(instance_id_hex, 16) {
                        Ok(instance_id) => {
                            tracing::debug!(
                                "KVStoreDiscovery::list_and_watch: Emitting Removed event for instance_id={}, key={}",
                                instance_id,
                                key_str
                            );
                            Some(DiscoveryEvent::Removed(instance_id))
                        }
                        Err(e) => {
                            tracing::warn!(
                                key = %key_str,
                                error = %e,
                                "Failed to parse instance_id hex from deleted key"
                            );
                            None
                        }
                    }
                }
            };
            if let Some(event) = discovery_event {
                tracing::debug!("KVStoreDiscovery::list_and_watch: Yielding event: {:?}", event);
                yield Ok(event);
            } else {
                tracing::debug!("KVStoreDiscovery::list_and_watch: Event was filtered out (None)");
            }
        }
        tracing::debug!("KVStoreDiscovery::list_and_watch: Stream ended after {} events for prefix={}", event_count, prefix);
    };
    tracing::debug!(
        "KVStoreDiscovery::list_and_watch: Returning stream for query={:?}",
        query
    );
    Ok(Box::pin(stream))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::component::TransportType;
// Registering an endpoint spec against the in-memory store must hand back an
// `Endpoint` instance carrying the same namespace, component and endpoint.
#[tokio::test]
async fn test_kv_store_discovery_register_endpoint() {
    let client = KVStoreDiscovery::new(KeyValueStoreManager::memory(), CancellationToken::new());

    let registered = client
        .register(DiscoverySpec::Endpoint {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        })
        .await
        .unwrap();

    let DiscoveryInstance::Endpoint(inst) = registered else {
        panic!("Expected Endpoint instance");
    };
    assert_eq!(inst.namespace, "test");
    assert_eq!(inst.component, "comp1");
    assert_eq!(inst.endpoint, "ep1");
}
// Registers three endpoints across two namespaces, then checks that listing
// by "all", by namespace and by namespace + component returns the right counts.
#[tokio::test]
async fn test_kv_store_discovery_list() {
    let client = KVStoreDiscovery::new(KeyValueStoreManager::memory(), CancellationToken::new());

    // (namespace, component, endpoint) for each registration, in order.
    let fixtures = [
        ("ns1", "comp1", "ep1"),
        ("ns1", "comp1", "ep2"),
        ("ns2", "comp2", "ep1"),
    ];
    for (namespace, component, endpoint) in fixtures {
        let spec = DiscoverySpec::Endpoint {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        };
        client.register(spec).await.unwrap();
    }

    let everything = client.list(DiscoveryQuery::AllEndpoints).await.unwrap();
    assert_eq!(everything.len(), 3);

    let by_namespace = DiscoveryQuery::NamespacedEndpoints {
        namespace: "ns1".to_string(),
    };
    let in_ns1 = client.list(by_namespace).await.unwrap();
    assert_eq!(in_ns1.len(), 2);

    let by_component = DiscoveryQuery::ComponentEndpoints {
        namespace: "ns1".to_string(),
        component: "comp1".to_string(),
    };
    let in_comp1 = client.list(by_component).await.unwrap();
    assert_eq!(in_comp1.len(), 2);
}
// Starts a watch on all endpoints, registers one endpoint from a background
// task, and expects the watch to deliver a matching `Added` event.
#[tokio::test]
async fn test_kv_store_discovery_watch() {
    let store = KeyValueStoreManager::memory();
    let cancel_token = CancellationToken::new();
    let client = Arc::new(KVStoreDiscovery::new(store, cancel_token.clone()));
    let mut stream = client
        .list_and_watch(DiscoveryQuery::AllEndpoints, None)
        .await
        .unwrap();
    let client_clone = client.clone();
    let register_task = tokio::spawn(async move {
        // Delay the registration so the stream below is already waiting for it.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        let spec = DiscoverySpec::Endpoint {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        };
        client_clone.register(spec).await.unwrap();
    });
    // Bound the wait: without a timeout, a regression that drops the event
    // would hang the test run instead of failing it.
    let event = tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next())
        .await
        .expect("timed out waiting for the Added event")
        .unwrap()
        .unwrap();
    match event {
        DiscoveryEvent::Added(instance) => match instance {
            DiscoveryInstance::Endpoint(inst) => {
                assert_eq!(inst.namespace, "test");
                assert_eq!(inst.component, "comp1");
                assert_eq!(inst.endpoint, "ep1");
            }
            _ => panic!("Expected Endpoint instance"),
        },
        _ => panic!("Expected Added event"),
    }
    register_task.await.unwrap();
    cancel_token.cancel();
}
}