use crate::{
api::{
agent::{agent_service_client::AgentServiceClient, PrepareArtifactRequest},
artifact::{
artifact_service_client::ArtifactServiceClient, Artifact, ArtifactRequest,
ArtifactSystem, ArtifactsRequest, ArtifactsResponse, GetArtifactAliasRequest,
},
context::context_service_server::{ContextService, ContextServiceServer},
},
artifact::system::get_system,
cli::{Cli, Command},
};
use anyhow::{anyhow, bail, Context, Result};
use clap::Parser;
use http::uri::{InvalidUri, Uri};
use oauth2::{basic::BasicClient, AuthUrl, ClientId, RefreshToken, TokenResponse, TokenUrl};
use serde::{Deserialize, Serialize};
use sha256::digest;
use std::{
collections::{BTreeMap, HashMap},
path::{Path, PathBuf},
};
use tokio::{
fs::{read, OpenOptions},
io::AsyncWriteExt,
};
use tonic::{
metadata::{Ascii, MetadataValue},
transport::{Certificate, Channel, ClientTlsConfig, Server},
Code::NotFound,
Request, Response, Status,
};
use tracing::info;
/// In-memory state accumulated while resolving artifacts.
///
/// A clone of this store is what `ConfigContext::run` hands to [`ConfigServer`].
#[derive(Clone)]
pub struct ConfigContextStore {
    /// Artifacts keyed by digest: the digest returned by the agent in `add_artifact`, or the
    /// digest an artifact was fetched under from the registry.
    artifact: HashMap<String, Artifact>,
    /// Maps the SHA-256 of an artifact's JSON encoding (its "input digest") to the digest the
    /// agent returned for it, so identical inputs are not prepared twice.
    artifact_input_cache: HashMap<String, String>,
    /// Variables parsed from `NAME=VALUE` command-line entries, keyed by name.
    variable: HashMap<String, String>,
}
/// Runtime state for one config run: CLI-supplied settings, gRPC clients for the agent and
/// the artifact registry, and the in-memory artifact store.
#[derive(Clone)]
pub struct ConfigContext {
    /// Artifact name given on the command line (returned by `get_artifact_name`).
    artifact: String,
    /// Context directory forwarded to the agent in each `PrepareArtifactRequest`.
    artifact_context: PathBuf,
    /// Namespace forwarded to the agent and used by `fetch_artifact`.
    artifact_namespace: String,
    /// Target system, parsed from the CLI string by `get_system`.
    artifact_system: ArtifactSystem,
    /// Unlock flag forwarded to the agent in each `PrepareArtifactRequest`.
    artifact_unlock: bool,
    /// Agent service client used by `add_artifact`.
    client_agent: AgentServiceClient<Channel>,
    /// Registry artifact service client used for digest and alias lookups.
    client_artifact: ArtifactServiceClient<Channel>,
    /// Port the context service binds in `run` (`[::]:<port>`); when `0`, agent output in
    /// `add_artifact` is logged through `tracing` instead of printed to stdout.
    port: u16,
    /// Registry address; forwarded to the agent and used to look up auth credentials.
    registry: String,
    /// Artifacts, input-digest cache, and variables gathered during this run.
    store: ConfigContextStore,
}
/// `ContextService` gRPC implementation that answers from a [`ConfigContextStore`].
#[derive(Clone)]
pub struct ConfigServer {
    /// Store the service reads artifacts from.
    pub store: ConfigContextStore,
}
/// OAuth2 credentials stored for a single issuer in the credentials file.
///
/// `Debug` is implemented by hand (not derived) so that `access_token` and `refresh_token`
/// are never written to logs or panic messages via `{:?}`. Field order is unchanged, so the
/// serialized JSON layout is unaffected.
#[derive(Deserialize, Serialize)]
pub struct VorpalCredentialsContent {
    /// Current bearer token sent to the registry.
    pub access_token: String,
    /// Optional `audience` parameter sent when refreshing; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audience: Option<String>,
    /// OAuth2 client id used for the refresh-token exchange.
    pub client_id: String,
    /// Token lifetime in seconds, counted from `issued_at`.
    pub expires_in: u64,
    /// Unix timestamp (seconds) at which `access_token` was obtained.
    pub issued_at: u64,
    /// Refresh token; an empty string means no refresh is possible.
    pub refresh_token: String,
    /// Scopes granted to this credential.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for VorpalCredentialsContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a placeholder; everything else is printed as derived Debug would.
        f.debug_struct("VorpalCredentialsContent")
            .field("access_token", &"<redacted>")
            .field("audience", &self.audience)
            .field("client_id", &self.client_id)
            .field("expires_in", &self.expires_in)
            .field("issued_at", &self.issued_at)
            .field("refresh_token", &"<redacted>")
            .field("scopes", &self.scopes)
            .finish()
    }
}
/// Parsed contents of the credentials file (see `get_key_credentials_path`).
#[derive(Debug, Deserialize, Serialize)]
pub struct VorpalCredentials {
    /// Stored credentials keyed by issuer; the key is also used as the OIDC issuer base URL
    /// when refreshing tokens.
    pub issuer: BTreeMap<String, VorpalCredentialsContent>,
    /// Maps a registry address to the issuer key whose credentials authenticate it.
    pub registry: BTreeMap<String, String>,
}
/// Namespace applied by `parse_artifact_alias` when an alias has no `namespace/` prefix.
pub const DEFAULT_NAMESPACE: &str = "library";
/// Tag applied by `parse_artifact_alias` when an alias has no `:tag` suffix.
pub const DEFAULT_TAG: &str = "latest";
/// An artifact alias `[namespace/]name[:tag]` after parsing, with defaults filled in.
#[derive(Clone, Debug, PartialEq)]
pub struct ArtifactAlias {
    /// Artifact name (always non-empty).
    pub name: String,
    /// Namespace; `DEFAULT_NAMESPACE` when the alias omitted it.
    pub namespace: String,
    /// Tag; `DEFAULT_TAG` when the alias omitted it.
    pub tag: String,
}
/// Returns `true` when `s` is a non-empty alias component made only of ASCII
/// alphanumerics and the punctuation `-`, `.`, `_`, `+`.
fn is_valid_component(s: &str) -> bool {
    const EXTRA_ALLOWED: [char; 4] = ['-', '.', '_', '+'];

    if s.is_empty() {
        return false;
    }

    // Reject as soon as any character falls outside the allowed set.
    !s.chars()
        .any(|ch| !(ch.is_ascii_alphanumeric() || EXTRA_ALLOWED.contains(&ch)))
}
/// Parses an artifact alias of the form `[namespace/]name[:tag]`.
///
/// The tag is whatever follows the last `:`; at most one `/` may separate the namespace from
/// the name. Every component that is present must consist of ASCII alphanumerics or `-`, `.`,
/// `_`, `+`. An omitted namespace or tag falls back to [`DEFAULT_NAMESPACE`] /
/// [`DEFAULT_TAG`].
///
/// # Errors
///
/// Returns an error for an empty alias, an alias longer than 255 bytes, an empty tag after
/// `:`, an empty namespace before `/`, more than one `/`, a missing name, or a component
/// containing disallowed characters.
pub fn parse_artifact_alias(alias: &str) -> Result<ArtifactAlias> {
    // Substitutes `fallback` for a component that was left out of the alias.
    fn or_default(value: &str, fallback: &str) -> String {
        let chosen = if value.is_empty() { fallback } else { value };
        chosen.to_string()
    }

    if alias.is_empty() {
        bail!("alias cannot be empty");
    }
    if alias.len() > 255 {
        bail!("alias too long (max 255 characters)");
    }

    // Everything after the last ':' is the tag; "" stands for "no tag given".
    let (base, tag) = match alias.rsplit_once(':') {
        None => (alias, ""),
        Some((_, "")) => bail!("tag cannot be empty"),
        Some(split) => split,
    };

    // A single '/' separates namespace from name; "" stands for "no namespace given".
    let (namespace, name) = match base.split_once('/') {
        None => ("", base),
        Some(("", _)) => bail!("namespace cannot be empty"),
        Some((_, rest)) if rest.contains('/') => {
            bail!("invalid format: too many path separators")
        }
        Some(split) => split,
    };

    if name.is_empty() {
        bail!("name is required");
    }
    if !is_valid_component(name) {
        bail!("name contains invalid characters (allowed: alphanumeric, hyphens, dots, underscores, plus signs)");
    }
    if !(namespace.is_empty() || is_valid_component(namespace)) {
        bail!("namespace contains invalid characters (allowed: alphanumeric, hyphens, dots, underscores, plus signs)");
    }
    if !(tag.is_empty() || is_valid_component(tag)) {
        bail!("tag contains invalid characters (allowed: alphanumeric, hyphens, dots, underscores, plus signs)");
    }

    Ok(ArtifactAlias {
        name: name.to_string(),
        namespace: or_default(namespace, DEFAULT_NAMESPACE),
        tag: or_default(tag, DEFAULT_TAG),
    })
}
impl ConfigServer {
pub fn new(store: ConfigContextStore) -> Self {
Self { store }
}
}
#[tonic::async_trait]
impl ContextService for ConfigServer {
async fn get_artifact(
&self,
request: Request<ArtifactRequest>,
) -> Result<Response<Artifact>, Status> {
let request = request.into_inner();
if request.digest.is_empty() {
return Err(tonic::Status::invalid_argument("'digest' is required"));
}
let artifact = self.store.artifact.get(request.digest.as_str());
if artifact.is_none() {
return Err(tonic::Status::not_found("artifact not found"));
}
Ok(Response::new(artifact.unwrap().clone()))
}
async fn get_artifacts(
&self,
_: tonic::Request<ArtifactsRequest>,
) -> Result<tonic::Response<ArtifactsResponse>, tonic::Status> {
let mut digests: Vec<String> = self.store.artifact.keys().cloned().collect();
digests.sort();
let response = ArtifactsResponse { digests };
Ok(Response::new(response))
}
}
/// Parses the command line and builds the [`ConfigContext`] for the `start` command.
///
/// Channels to the agent and to the registry are built (agent first) with `build_channel`,
/// then every CLI value is handed to [`ConfigContext::new`].
///
/// # Errors
///
/// Fails when either channel cannot be built or when `ConfigContext::new` rejects the
/// requested system.
pub async fn get_context() -> Result<ConfigContext> {
    let args = Cli::parse();

    match args.command {
        Command::Start {
            agent,
            artifact,
            artifact_context,
            artifact_namespace,
            artifact_system,
            artifact_unlock,
            artifact_variable,
            port,
            registry,
        } => {
            let client_agent = AgentServiceClient::new(build_channel(&agent).await?);
            let client_artifact = ArtifactServiceClient::new(build_channel(&registry).await?);

            ConfigContext::new(
                artifact,
                PathBuf::from(artifact_context),
                artifact_namespace,
                artifact_system,
                artifact_unlock,
                artifact_variable,
                client_agent,
                client_artifact,
                port,
                registry,
            )
        }
    }
}
impl ConfigContext {
#[allow(clippy::too_many_arguments)]
pub fn new(
artifact: String,
artifact_context: PathBuf,
artifact_namespace: String,
artifact_system: String,
artifact_unlock: bool,
artifact_variable: Vec<String>,
client_agent: AgentServiceClient<Channel>,
client_artifact: ArtifactServiceClient<Channel>,
port: u16,
registry: String,
) -> Result<Self> {
Ok(Self {
artifact,
artifact_context,
client_agent,
client_artifact,
artifact_namespace,
port,
registry,
store: ConfigContextStore {
artifact: HashMap::new(),
artifact_input_cache: HashMap::new(),
variable: artifact_variable
.iter()
.map(|v| {
let mut parts = v.split('=');
let name = parts.next().unwrap_or_default();
let value = parts.next().unwrap_or_default();
(name.to_string(), value.to_string())
})
.collect(),
},
artifact_system: get_system(&artifact_system)?,
artifact_unlock,
})
}
pub async fn add_artifact(&mut self, artifact: &Artifact) -> Result<String> {
if artifact.name.is_empty() {
bail!("name cannot be empty");
}
if artifact.steps.is_empty() {
bail!("steps cannot be empty");
}
if artifact.systems.is_empty() {
bail!("systems cannot be empty");
}
if !artifact.systems.contains(&artifact.target) {
bail!(
"artifact '{}' does not support system '{:?}' (supported: {:?})",
artifact.name,
ArtifactSystem::try_from(artifact.target).unwrap_or(ArtifactSystem::UnknownSystem),
artifact
.systems
.iter()
.filter_map(|&s| ArtifactSystem::try_from(s).ok())
.collect::<Vec<_>>()
);
}
let artifact_json =
serde_json::to_vec(&artifact).expect("failed to serialize artifact to JSON");
let input_digest = digest(artifact_json.clone());
if self.store.artifact.contains_key(&input_digest) {
return Ok(input_digest);
}
if let Some(output_digest) = self.store.artifact_input_cache.get(&input_digest) {
if self.store.artifact.contains_key(output_digest) {
return Ok(output_digest.clone());
}
}
let request = PrepareArtifactRequest {
artifact: Some(artifact.clone()),
artifact_context: self.artifact_context.display().to_string(),
artifact_namespace: self.artifact_namespace.clone(),
artifact_unlock: self.artifact_unlock,
registry: self.registry.clone(),
};
let mut request = Request::new(request);
let request_auth = client_auth_header(&self.registry).await?;
if let Some(header) = request_auth {
request.metadata_mut().insert("authorization", header);
}
let response = self
.client_agent
.prepare_artifact(request)
.await
.expect("failed to prepare artifact");
let mut response = response.into_inner();
let mut response_artifact = None;
let mut response_artifact_digest = None;
loop {
match response.message().await {
Ok(Some(message)) => {
if let Some(artifact_output) = message.artifact_output {
if self.port == 0 {
info!("{} |> {}", artifact.name, artifact_output);
} else {
println!("{} |> {}", artifact.name, artifact_output);
}
}
response_artifact = message.artifact;
response_artifact_digest = message.artifact_digest;
}
Ok(None) => break,
Err(status) => {
if status.code() != NotFound {
bail!("{}", status.message());
}
break;
}
}
}
if response_artifact.is_none() {
bail!("artifact not returned from agent service");
}
if response_artifact_digest.is_none() {
bail!("artifact digest not returned from agent service");
}
let artifact = response_artifact.unwrap();
let artifact_digest = response_artifact_digest.unwrap();
self.store
.artifact
.insert(artifact_digest.clone(), artifact.clone());
self.store
.artifact_input_cache
.insert(input_digest, artifact_digest.clone());
Ok(artifact_digest)
}
pub async fn fetch_artifact(&mut self, digest: &str) -> Result<String> {
self.fetch_artifact_in_namespace(digest, &self.artifact_namespace.clone())
.await
}
async fn fetch_artifact_in_namespace(
&mut self,
digest: &str,
namespace: &str,
) -> Result<String> {
if self.store.artifact.contains_key(digest) {
return Ok(digest.to_string());
}
let request = ArtifactRequest {
digest: digest.to_string(),
namespace: namespace.to_string(),
};
let mut request = Request::new(request.clone());
let request_auth = client_auth_header(&self.registry).await?;
if let Some(header) = request_auth {
request.metadata_mut().insert("authorization", header);
}
match self.client_artifact.get_artifact(request).await {
Err(status) => {
if status.code() != NotFound {
bail!("artifact service error: {:?}", status);
}
bail!("artifact not found: {}", digest);
}
Ok(response) => {
let artifact = response.into_inner();
self.store
.artifact
.insert(digest.to_string(), artifact.clone());
for step in artifact.steps.iter() {
for dep in step.artifacts.iter() {
Box::pin(self.fetch_artifact_in_namespace(dep, namespace)).await?;
}
}
Ok(digest.to_string())
}
}
}
pub async fn fetch_artifact_alias(&mut self, alias: &str) -> Result<String> {
let alias_parsed = parse_artifact_alias(alias)?;
let request = GetArtifactAliasRequest {
system: self.artifact_system.into(),
name: alias_parsed.name,
namespace: alias_parsed.namespace.clone(),
tag: alias_parsed.tag,
};
let mut request = Request::new(request);
let request_auth = client_auth_header(&self.registry).await?;
if let Some(header) = request_auth {
request.metadata_mut().insert("authorization", header);
}
let response = self
.client_artifact
.get_artifact_alias(request)
.await
.map_err(|status| {
if status.code() == NotFound {
anyhow!("alias not found in registry: {}", alias)
} else {
anyhow!("registry error: {:?}", status)
}
})?;
let digest = response.into_inner().digest;
if digest.is_empty() {
bail!("registry returned empty digest for alias: {}", alias);
}
if self.store.artifact.contains_key(&digest) {
return Ok(digest);
}
self.fetch_artifact_in_namespace(&digest, &alias_parsed.namespace)
.await?;
Ok(digest)
}
pub fn get_artifact_store(&self) -> HashMap<String, Artifact> {
self.store.artifact.clone()
}
pub fn get_artifact(&self, digest: &str) -> Option<Artifact> {
self.store.artifact.get(digest).cloned()
}
pub fn get_artifact_context_path(&self) -> &PathBuf {
&self.artifact_context
}
pub fn get_artifact_name(&self) -> &str {
self.artifact.as_str()
}
pub fn get_artifact_namespace(&self) -> &str {
self.artifact_namespace.as_str()
}
pub fn get_system(&self) -> ArtifactSystem {
self.artifact_system
}
pub fn get_variable(&self, name: &str) -> Option<String> {
self.store.variable.get(name).cloned()
}
pub async fn run(&self) -> Result<()> {
let service = ContextServiceServer::new(ConfigServer::new(self.store.clone()));
let service_addr_str = format!("[::]:{}", self.port);
let service_addr = service_addr_str.parse().expect("failed to parse address");
println!("context service: {service_addr_str}");
Server::builder()
.add_service(service)
.serve(service_addr)
.await
.map_err(|e| anyhow::anyhow!("failed to serve: {}", e))
}
}
/// Root directory for Vorpal's on-disk state: `/var/lib/vorpal`.
pub fn get_root_dir_path() -> PathBuf {
    PathBuf::from("/var/lib/vorpal")
}
/// Directory holding key material: `<root>/key`.
pub fn get_root_key_dir_path() -> PathBuf {
    let mut key_dir = get_root_dir_path();
    key_dir.push("key");
    key_dir
}
/// Path of the optional extra CA certificate: `<key dir>/ca.pem`.
pub fn get_key_ca_path() -> PathBuf {
    get_root_key_dir_path().join("ca.pem")
}
/// Path of the OAuth2 credentials file: `<key dir>/credentials.json`.
pub fn get_key_credentials_path() -> PathBuf {
    get_root_key_dir_path().join("credentials.json")
}
async fn get_client_tls_config(uri: &str) -> Result<Option<ClientTlsConfig>> {
if uri.starts_with("http://") || uri.starts_with("unix://") {
return Ok(None);
}
let ca_pem_path = get_key_ca_path();
let mut client_tls_config = ClientTlsConfig::new().with_native_roots();
if ca_pem_path.exists() {
let ca_pem = read(&ca_pem_path)
.await
.with_context(|| format!("failed to read CA certificate: {}", ca_pem_path.display()))?;
client_tls_config = client_tls_config.ca_certificate(Certificate::from_pem(ca_pem));
}
Ok(Some(client_tls_config))
}
/// Builds a gRPC channel to `uri`.
///
/// * `unix://<path>` — a lazily-connecting channel over a Unix domain socket at `<path>`.
/// * `http://…` / `https://…` — an eagerly-connected channel; TLS settings come from
///   `get_client_tls_config`.
///
/// # Errors
///
/// Fails for any other scheme, an unparsable URI, a CA file that cannot be read, an invalid
/// TLS configuration, or a failed connection attempt.
pub async fn build_channel(uri: &str) -> Result<Channel> {
    if let Some(socket_path) = uri.strip_prefix("unix://") {
        let socket_path = socket_path.to_string();
        // A placeholder endpoint URI is needed to build the Channel; it is never dialed
        // because the connector below ignores the URI it receives.
        let channel = Channel::from_static("http://[::]:50051").connect_with_connector_lazy(
            tower::service_fn(move |_: tonic::transport::Uri| {
                let path = socket_path.clone();
                async move {
                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
                        tokio::net::UnixStream::connect(path).await?,
                    ))
                }
            }),
        );
        return Ok(channel);
    }

    if !uri.starts_with("http://") && !uri.starts_with("https://") {
        bail!("URI must start with http://, https://, or unix://: {}", uri);
    }

    let parsed_uri = uri
        .parse::<Uri>()
        .map_err(|e: InvalidUri| anyhow!("invalid URI: {}", e))?;
    let tls_config = get_client_tls_config(uri).await?;

    let mut endpoint = Channel::builder(parsed_uri);
    if let Some(tls) = tls_config {
        endpoint = endpoint.tls_config(tls)?;
    }

    endpoint
        .connect()
        .await
        .with_context(|| format!("failed to connect to {}", uri))
}
/// Exchanges `refresh_token` for a new access token at `issuer`'s token endpoint.
///
/// The token endpoint is read from the issuer's OIDC discovery document. When `audience` is
/// set it is sent as an extra `audience` parameter.
///
/// Returns `(access_token, expires_in_secs, issued_at_unix_secs, rotated_refresh_token)`.
/// `expires_in` defaults to 3600 when the response omits it. The rotated refresh token is
/// `None` when the response omits it or sends an empty string.
///
/// # Errors
///
/// Fails when discovery cannot be fetched, returns an HTTP error status, is not JSON, or lacks
/// `token_endpoint`; when either URL is invalid; or when the token exchange fails.
async fn refresh_access_token(
    audience: Option<&str>,
    client_id: &str,
    issuer: &str,
    refresh_token: &str,
) -> Result<(String, u64, u64, Option<String>)> {
    // OIDC Discovery 1.0 §4: a terminating '/' on the issuer must be removed before appending
    // the well-known path, otherwise issuers like `https://idp/` yield `//.well-known`.
    let discovery_url = format!(
        "{}/.well-known/openid-configuration",
        issuer.trim_end_matches('/')
    );
    let doc: serde_json::Value = reqwest::get(&discovery_url)
        .await
        .with_context(|| format!("failed to fetch OIDC discovery document: {}", discovery_url))?
        // Surface 4xx/5xx as HTTP errors instead of a confusing JSON decode failure.
        .error_for_status()
        .with_context(|| format!("OIDC discovery request failed: {}", discovery_url))?
        .json()
        .await
        .with_context(|| format!("invalid OIDC discovery document: {}", discovery_url))?;

    let token_endpoint = doc
        .get("token_endpoint")
        .and_then(|v| v.as_str())
        .ok_or_else(|| anyhow!("missing token_endpoint in OIDC discovery"))?;

    let client = BasicClient::new(ClientId::new(client_id.to_string()))
        .set_auth_uri(AuthUrl::new(issuer.to_string())?)
        .set_token_uri(TokenUrl::new(token_endpoint.to_string())?);

    // The oauth2 crate recommends an HTTP client that does not follow redirects, so a
    // redirecting token endpoint cannot steer the refresh token elsewhere (SSRF).
    let http_client = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .context("failed to build HTTP client for token refresh")?;

    let refresh_token_obj = RefreshToken::new(refresh_token.to_string());
    let mut request = client.exchange_refresh_token(&refresh_token_obj);
    if let Some(aud) = audience {
        request = request.add_extra_param("audience", aud);
    }
    let token_result = request
        .request_async(&http_client)
        .await
        .context("failed to refresh access token")?;

    let new_access_token = token_result.access_token().secret().to_string();
    let new_expires_in = token_result
        .expires_in()
        .map(|d| d.as_secs())
        .unwrap_or(3600);
    let new_refresh_token = normalize_rotated_refresh_token(
        token_result.refresh_token().map(|t| t.secret().to_string()),
    );
    let issued_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)?
        .as_secs();

    Ok((
        new_access_token,
        new_expires_in,
        issued_at,
        new_refresh_token,
    ))
}
/// Treats an empty-string refresh token from the token endpoint the same as no rotation.
fn normalize_rotated_refresh_token(raw: Option<String>) -> Option<String> {
    match raw {
        Some(token) if !token.is_empty() => Some(token),
        _ => None,
    }
}
fn apply_token_refresh(
creds: &mut VorpalCredentialsContent,
access_token: String,
expires_in: u64,
issued_at: u64,
rotated_refresh_token: Option<String>,
) {
creds.access_token = access_token;
creds.expires_in = expires_in;
creds.issued_at = issued_at;
if let Some(new) = rotated_refresh_token {
creds.refresh_token = new;
}
}
/// Writes `bytes` to `path` (creating or truncating it) with owner-only 0o600 permissions.
///
/// `OpenOptions::mode` only takes effect when the file is newly created. A pre-existing file
/// that grants any group/other access (for example one written before this function
/// existed) is therefore tightened to 0o600 explicitly, before the secret bytes are written.
///
/// # Errors
///
/// Propagates any error from opening, inspecting, changing permissions on, writing, or
/// flushing the file.
async fn write_credentials_secure(path: &Path, bytes: &[u8]) -> Result<()> {
    use std::os::unix::fs::PermissionsExt;

    const OWNER_ONLY: u32 = 0o600;

    let mut file = OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .mode(OWNER_ONLY)
        .open(path)
        .await?;

    // Only touch permissions when group/other bits are set, so an already-private file is
    // left exactly as it was.
    let current_mode = file.metadata().await?.permissions().mode();
    if (current_mode & 0o077) != 0 {
        file.set_permissions(std::fs::Permissions::from_mode(OWNER_ONLY))
            .await?;
    }

    file.write_all(bytes).await?;
    file.flush().await?;
    Ok(())
}
pub async fn client_auth_header(registry: &str) -> Result<Option<MetadataValue<Ascii>>> {
let credentials_path = get_key_credentials_path();
if !credentials_path.exists() {
return Ok(None);
}
let credentials_data = read(&credentials_path).await?;
let mut credentials: VorpalCredentials = serde_json::from_slice(&credentials_data)?;
let registry_issuer = match credentials.registry.get(registry) {
Some(issuer) => issuer.clone(),
None => return Ok(None),
};
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)?
.as_secs();
let needs_refresh = {
let issuer_creds = credentials
.issuer
.get(®istry_issuer)
.ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?;
let token_age = now - issuer_creds.issued_at;
let expires_in = issuer_creds.expires_in;
token_age + 300 >= expires_in
};
if needs_refresh {
let (audience, client_id, refresh_token) = {
let issuer_creds = credentials
.issuer
.get(®istry_issuer)
.ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?;
(
issuer_creds.audience.clone(),
issuer_creds.client_id.clone(),
issuer_creds.refresh_token.clone(),
)
};
if refresh_token.is_empty() {
return Err(anyhow!(
"Access token expired and no refresh token available. Please run: vorpal login --issuer {}",
registry_issuer
));
}
let (new_token, new_expires, new_issued_at, rotated_refresh) = refresh_access_token(
audience.as_deref(),
&client_id,
®istry_issuer,
&refresh_token,
)
.await?;
let issuer_creds = credentials
.issuer
.get_mut(®istry_issuer)
.ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?;
apply_token_refresh(
issuer_creds,
new_token,
new_expires,
new_issued_at,
rotated_refresh,
);
let credentials_json = serde_json::to_string_pretty(&credentials)?;
write_credentials_secure(&credentials_path, credentials_json.as_bytes()).await?;
}
let access_token = credentials
.issuer
.get(®istry_issuer)
.ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?
.access_token
.clone();
let header = format!("Bearer {}", access_token)
.parse()
.map_err(|e| anyhow!("failed to parse Bearer token: {}", e))?;
Ok(Some(header))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline credentials every token-refresh test starts from.
    fn sample_creds() -> VorpalCredentialsContent {
        VorpalCredentialsContent {
            access_token: "old-access".into(),
            audience: Some("aud-1".into()),
            client_id: "client-1".into(),
            expires_in: 3600,
            issued_at: 1_700_000_000,
            refresh_token: "old-refresh".into(),
            scopes: Vec::from(["openid", "offline_access"].map(String::from)),
        }
    }

    /// Applies the refresh shared by these tests (access token "new-access", 7200 s lifetime,
    /// issued at 1_700_000_500) to fresh sample credentials, with the given rotation.
    fn refreshed_with(rotated: Option<&str>) -> VorpalCredentialsContent {
        let mut creds = sample_creds();
        apply_token_refresh(
            &mut creds,
            "new-access".to_string(),
            7200,
            1_700_000_500,
            rotated.map(String::from),
        );
        creds
    }

    /// Checks the fields a refresh is expected to overwrite.
    fn assert_refreshed(creds: &VorpalCredentialsContent, expected_refresh_token: &str) {
        assert_eq!(creds.access_token, "new-access");
        assert_eq!(creds.expires_in, 7200);
        assert_eq!(creds.issued_at, 1_700_000_500);
        assert_eq!(creds.refresh_token, expected_refresh_token);
    }

    /// Checks the fields a refresh must leave untouched.
    fn assert_identity_preserved(creds: &VorpalCredentialsContent) {
        assert_eq!(creds.audience.as_deref(), Some("aud-1"));
        assert_eq!(creds.client_id, "client-1");
        assert_eq!(creds.scopes, vec!["openid", "offline_access"]);
    }

    #[test]
    fn apply_token_refresh_replaces_refresh_when_rotated() {
        let creds = refreshed_with(Some("rotated-refresh"));
        assert_refreshed(&creds, "rotated-refresh");
        assert_identity_preserved(&creds);
    }

    #[test]
    fn apply_token_refresh_keeps_refresh_when_not_rotated() {
        let creds = refreshed_with(None);
        assert_refreshed(&creds, "old-refresh");
        assert_identity_preserved(&creds);
    }

    #[test]
    fn normalize_rotated_refresh_token_some_nonempty_passes_through() {
        let rotated = Some(String::from("rotated-refresh"));
        assert_eq!(normalize_rotated_refresh_token(rotated.clone()), rotated);
    }

    #[test]
    fn normalize_rotated_refresh_token_some_empty_becomes_none() {
        assert_eq!(
            normalize_rotated_refresh_token(Some(String::new())),
            None,
            "empty-string refresh_token must be treated as not-rotated for parity with Go/TS"
        );
    }

    #[test]
    fn normalize_rotated_refresh_token_none_passes_through() {
        assert!(normalize_rotated_refresh_token(None).is_none());
    }

    #[test]
    fn write_credentials_secure_creates_file_with_mode_0o600() {
        use std::os::unix::fs::PermissionsExt;

        // Unique per process and instant so parallel or repeated runs never collide.
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!(
            "vorpal-creds-mode-test-{}-{}",
            std::process::id(),
            nanos
        ));
        std::fs::create_dir_all(&dir).expect("create temp dir");

        let path = dir.join("credentials.json");
        assert!(!path.exists(), "test path must be previously-nonexistent");

        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("build runtime")
            .block_on(write_credentials_secure(&path, b"{\"hello\":\"world\"}"))
            .expect("write credentials");

        let permission_bits = std::fs::metadata(&path)
            .expect("stat credentials")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(
            permission_bits,
            0o600,
            "credentials file must be born 0o600, got {:o}",
            permission_bits
        );

        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_dir(&dir);
    }

    #[test]
    fn apply_token_refresh_persists_through_serde_roundtrip() {
        let original = refreshed_with(Some("rotated-refresh"));

        let json = serde_json::to_string(&original).expect("serialize");
        let parsed: VorpalCredentialsContent = serde_json::from_str(&json).expect("deserialize");

        assert_refreshed(&parsed, "rotated-refresh");
        assert_identity_preserved(&parsed);
    }
}