use std::time::Duration;
use anyhow::{Context, Result, anyhow};
use astrid_core::PrincipalId;
use astrid_core::kernel_api::{KernelRequest, KernelResponse};
use astrid_types::ipc::{IpcMessage, IpcPayload};
use uuid::Uuid;
use crate::socket_client::SocketClient;
/// Topic prefix for outgoing kernel requests; [`KernelClient::request`] appends
/// `<kind>.<correlation id>` to it.
const REQUEST_PREFIX: &str = "astrid.v1.request.";
/// Topic prefix on which the reply to a request is awaited; the full topic uses
/// the same `<kind>.<correlation id>` suffix as the request it answers.
const RESPONSE_PREFIX: &str = "astrid.v1.response.";
/// Response timeout a freshly connected [`KernelClient`] uses until it is
/// overridden with [`KernelClient::with_timeout`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(15);
#[must_use]
pub const fn topic_suffix(req: &KernelRequest) -> &'static str {
match req {
KernelRequest::InstallCapsule { .. } => "install_capsule",
KernelRequest::ApproveCapability { .. } => "approve_capability",
KernelRequest::ListCapsules => "list_capsules",
KernelRequest::ReloadCapsules => "reload_capsules",
KernelRequest::GetCommands => "get_commands",
KernelRequest::GetCapsuleMetadata => "metadata",
KernelRequest::Shutdown { .. } => "shutdown",
KernelRequest::GetStatus => "status",
}
}
/// Request/response client for the kernel's management API, layered on a
/// [`SocketClient`] connection to the Astrid daemon.
///
/// Every request is tagged with `caller` as its principal, and replies are
/// awaited using the client's configured `timeout`.
pub struct KernelClient {
    // Daemon socket connection used to send requests and read replies.
    inner: SocketClient,
    // Principal attached to every outgoing request message.
    caller: PrincipalId,
    // Timeout handed to `read_until_topic` when waiting for a reply.
    timeout: Duration,
}
impl KernelClient {
    /// Connects to the Astrid daemon on a fresh, randomly generated session and
    /// returns a client that issues requests as `caller`.
    ///
    /// The client starts with `DEFAULT_TIMEOUT` (15 seconds) as its response
    /// timeout; use [`KernelClient::with_timeout`] to change it.
    ///
    /// # Errors
    ///
    /// Returns an error if [`SocketClient::connect`] fails; the error carries a
    /// hint to run `astrid start`.
    pub async fn connect(caller: PrincipalId) -> Result<Self> {
        let session_id = astrid_core::SessionId::from_uuid(Uuid::new_v4());
        let inner = SocketClient::connect(session_id)
            .await
            .context("Failed to connect to Astrid daemon. Run `astrid start` to launch it.")?;
        Ok(Self {
            inner,
            caller,
            timeout: DEFAULT_TIMEOUT,
        })
    }

    /// Overrides how long [`KernelClient::request`] waits for each reply.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sends `req` to the kernel and waits for its correlated response.
    ///
    /// Each call mints a fresh correlation id. The request is published on
    /// `astrid.v1.request.<kind>.<id>`, and the reply is awaited on
    /// `astrid.v1.response.<kind>.<id>`, where `<kind>` comes from
    /// [`topic_suffix`].
    ///
    /// This method does not inspect the response variant. Pass the result
    /// through [`into_result`] to turn a [`KernelResponse::Error`] into an `Err`.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be serialised or sent, if waiting on the
    /// response topic fails (including hitting the configured timeout), or if
    /// the reply does not deserialise as a [`KernelResponse`].
    pub async fn request(&mut self, req: KernelRequest) -> Result<KernelResponse> {
        // The random correlation id makes the response topic unique to this
        // call, so the wait below only matches this request's reply.
        let correlation = Uuid::new_v4().simple().to_string();
        let suffix = format!("{}.{correlation}", topic_suffix(&req));
        let request_topic = format!("{REQUEST_PREFIX}{suffix}");
        let want_response = format!("{RESPONSE_PREFIX}{suffix}");

        let payload = serde_json::to_value(&req).context("serialise KernelRequest")?;
        let msg = IpcMessage::new(request_topic, IpcPayload::RawJson(payload), Uuid::nil())
            .with_principal(self.caller.to_string());

        // Annotate send failures like every other failure path here, so callers
        // can tell which request could not be delivered.
        self.inner
            .send_message(msg)
            .await
            .with_context(|| format!("sending request on {REQUEST_PREFIX}{suffix}"))?;

        let raw = self
            .inner
            .read_until_topic(&want_response, self.timeout)
            .await
            .with_context(|| format!("waiting on {want_response}"))?;
        SocketClient::extract_kernel_response(&raw).ok_or_else(|| {
            anyhow!("kernel response on {want_response} did not deserialize as KernelResponse")
        })
    }

    /// Returns the principal this client attaches to every request.
    #[must_use]
    pub const fn caller(&self) -> &PrincipalId {
        &self.caller
    }
}
pub fn into_result(resp: KernelResponse) -> Result<KernelResponse> {
match resp {
KernelResponse::Error(msg) => Err(anyhow!("kernel rejected request: {msg}")),
other => Ok(other),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn topic_suffixes_match_cli_conventions() {
        assert_eq!(topic_suffix(&KernelRequest::GetStatus), "status");
        assert_eq!(topic_suffix(&KernelRequest::ListCapsules), "list_capsules");
        assert_eq!(topic_suffix(&KernelRequest::GetCommands), "get_commands");
        assert_eq!(topic_suffix(&KernelRequest::GetCapsuleMetadata), "metadata");
        assert_eq!(
            topic_suffix(&KernelRequest::ReloadCapsules),
            "reload_capsules"
        );
        assert_eq!(
            topic_suffix(&KernelRequest::Shutdown { reason: None }),
            "shutdown"
        );
    }

    #[test]
    fn topic_prefixes_are_distinct_and_dot_terminated() {
        // `request` builds topics as `{PREFIX}{suffix}`. Without a trailing dot
        // the kind would fuse into the prefix, and identical prefixes would
        // make a request indistinguishable from its reply.
        assert_ne!(REQUEST_PREFIX, RESPONSE_PREFIX);
        assert!(REQUEST_PREFIX.ends_with('.'));
        assert!(RESPONSE_PREFIX.ends_with('.'));
    }

    #[test]
    fn into_result_lifts_error_variant() {
        let err = KernelResponse::Error("not allowed".into());
        let res = into_result(err);
        assert!(res.is_err());
        // Pin the whole message, not just the kernel-supplied reason, so a
        // regression in the prefix or formatting is caught.
        assert_eq!(
            res.unwrap_err().to_string(),
            "kernel rejected request: not allowed"
        );
    }
}