use std::{
collections::BTreeMap,
future::Future,
sync::Arc,
time::{Duration, Instant},
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value, json};
use tokio::sync::{Mutex, OnceCell};
use crate::{
CommandResult, Credential, CredentialRequest, Dispatcher, Result, SchemaRegistry, Tier,
error::{CliCoreError, exit_code_for_error},
output::{
Envelope, HumanViewRegistry, OutputFormat, PipelineOpts, apply_pipeline,
build_error_envelope, is_valid_output_format, render_human_with_registry_selected,
},
};
/// JSON object map (`serde_json::Map`) used throughout the middleware for
/// command arguments and activity-event payloads.
pub type ValueMap = Map<String, Value>;
/// Static, per-command metadata consulted by [`Middleware::run`].
///
/// Auth-related settings live in `auth_metadata` as string key/value pairs;
/// the accessors on this type read the `provider`, `tier` and `fixed_env`
/// keys, and `set_scopes` maintains the `scopes` key.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct CommandMeta {
    /// When `true` and the middleware runs in dry-run mode, the command is not
    /// executed; a dry-run envelope is returned instead.
    pub dry_run_prompt: bool,
    /// Free-form auth settings; a copy of this metadata is sent with every
    /// credential request.
    pub auth_metadata: BTreeMap<String, String>,
    /// Baseline scopes requested on every credential resolution. Prefer
    /// `set_scopes`, which keeps `auth_metadata["scopes"]` in sync.
    pub scopes: Vec<String>,
}
impl CommandMeta {
    /// Auth provider named by `auth_metadata["provider"]`, if present.
    /// An empty value is returned as `Some("")`.
    #[must_use]
    pub fn provider(&self) -> Option<&str> {
        self.metadata_value("provider")
    }

    /// Privilege tier parsed from `auth_metadata["tier"]`.
    ///
    /// A missing or unparsable value yields [`Tier::Read`].
    #[must_use]
    pub fn tier(&self) -> Tier {
        match self.auth_metadata.get("tier").map(|raw| raw.parse::<Tier>()) {
            Some(Ok(tier)) => tier,
            _ => Tier::Read,
        }
    }

    /// Environment pinned by `auth_metadata["fixed_env"]`, if any. When set,
    /// it replaces the middleware environment for credential requests.
    #[must_use]
    pub fn fixed_env(&self) -> Option<&str> {
        self.metadata_value("fixed_env")
    }

    /// Replaces the requested scopes and mirrors them into
    /// `auth_metadata["scopes"]` as one space-separated string; an empty list
    /// removes that key instead.
    pub fn set_scopes(&mut self, scopes: Vec<String>) {
        match scopes.as_slice() {
            [] => {
                self.auth_metadata.remove("scopes");
            }
            all => {
                self.auth_metadata.insert("scopes".to_owned(), all.join(" "));
            }
        }
        self.scopes = scopes;
    }

    /// Borrowed view of a single `auth_metadata` entry.
    fn metadata_value(&self, key: &str) -> Option<&str> {
        self.auth_metadata.get(key).map(String::as_str)
    }
}
/// How a command interacts with credential resolution in [`Middleware::run`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum AuthRequirement {
    /// A credential is resolved before the command runs; a failure aborts the
    /// command with an `auth-error` outcome. This is the default.
    #[default]
    Required,
    /// No credential is resolved up front; the command may resolve one itself
    /// through the [`CredentialResolver`] it receives.
    Optional,
    /// The command runs without auth: the resolver refuses to resolve, no
    /// `env` argument is injected, and schema requests are answered before
    /// the authorizer runs.
    None,
}
impl AuthRequirement {
    /// `true` for [`AuthRequirement::None`] (no credential is ever resolved).
    #[must_use]
    pub fn is_none(self) -> bool {
        self == Self::None
    }

    /// `true` for [`AuthRequirement::Required`] (credential resolved up front).
    #[must_use]
    pub fn is_required(self) -> bool {
        self == Self::Required
    }

    /// `true` for [`AuthRequirement::Optional`] (credential resolved on demand).
    #[must_use]
    pub fn is_optional(self) -> bool {
        self == Self::Optional
    }
}
/// Lazily resolves, and caches, the credential for a single command
/// invocation.
///
/// Cloning is cheap and all clones share one cache, so a credential resolved
/// while authorizing is reused when the command itself asks for it.
#[derive(Clone)]
pub struct CredentialResolver {
    inner: Arc<ResolverInner>,
}
/// Shared state behind every clone of a [`CredentialResolver`].
#[derive(Debug)]
struct ResolverInner {
    /// Dispatcher that fetches credentials from `provider`.
    auth: Dispatcher,
    /// Name of the auth provider to ask.
    provider: String,
    /// Environment the credential is requested for.
    env: String,
    /// Command path forwarded in each credential request.
    command_path: String,
    /// Raw tier string (may be empty), forwarded in each credential request.
    tier: String,
    /// When `true`, every `resolve*` call is refused.
    no_auth: bool,
    /// Command metadata; its `scopes` are requested on every resolution.
    meta: CommandMeta,
    /// Latest credential plus cumulative scopes. The lock is held across the
    /// provider call, so concurrent resolutions are serialized.
    state: Mutex<ResolveState>,
    /// First credential ever resolved; backs the non-async `peek`. It is not
    /// replaced after a scope step-up.
    cell: OnceCell<Credential>,
}
/// Mutable resolution state guarded by `ResolverInner::state`.
#[derive(Debug, Default)]
struct ResolveState {
    /// Most recently resolved credential, if any.
    credential: Option<Credential>,
    /// Every scope requested so far, deduplicated, in first-requested order.
    requested: Vec<String>,
}
impl std::fmt::Debug for CredentialResolver {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter
.debug_struct("CredentialResolver")
.field("provider", &self.inner.provider)
.field("env", &self.inner.env)
.field("no_auth", &self.inner.no_auth)
.field("resolved", &self.inner.cell.get().is_some())
.finish_non_exhaustive()
}
}
impl CredentialResolver {
    /// Creates a resolver for one command invocation. No credential is
    /// fetched until one of the `resolve*` methods is awaited.
    fn new(
        auth: Dispatcher,
        provider: String,
        env: String,
        command_path: String,
        tier: String,
        no_auth: bool,
        meta: CommandMeta,
    ) -> Self {
        Self {
            inner: Arc::new(ResolverInner {
                auth,
                provider,
                env,
                command_path,
                tier,
                no_auth,
                meta,
                state: Mutex::new(ResolveState::default()),
                cell: OnceCell::new(),
            }),
        }
    }

    /// Resolves a credential carrying the command's baseline scopes.
    ///
    /// A cached credential is returned when it already covers those scopes.
    ///
    /// # Errors
    /// Fails when the command is marked `no_auth`, when the provider fails
    /// (normalized by `auth_resolution_error`), or when a step-up yields a
    /// credential for a different principal.
    pub async fn resolve(&self) -> Result<Credential> {
        self.resolve_with_scopes(&[]).await
    }

    /// Resolves a credential carrying the baseline scopes plus `extra`.
    ///
    /// When the cached credential was not requested with all of them, a new
    /// credential is requested for the union of every scope asked for so far
    /// (a "step-up").
    ///
    /// # Errors
    /// Same as [`CredentialResolver::resolve`].
    pub async fn resolve_with_scopes(&self, extra: &[String]) -> Result<Credential> {
        if self.inner.no_auth {
            return Err(CliCoreError::message(
                "command is marked no_auth and has no credential",
            ));
        }
        self.resolve_scopes(extra).await
    }

    /// Shared implementation of the `resolve*` methods (no `no_auth` check).
    async fn resolve_scopes(&self, extra: &[String]) -> Result<Credential> {
        let inner = &self.inner;
        // Baseline scopes followed by any new extra scopes, without duplicates.
        let mut want = inner.meta.scopes.clone();
        for scope in extra {
            if !want.contains(scope) {
                want.push(scope.clone());
            }
        }
        // Held across the provider call below so concurrent callers do not
        // trigger duplicate fetches.
        let mut state = inner.state.lock().await;
        if let Some(credential) = &state.credential
            && want.iter().all(|scope| state.requested.contains(scope))
        {
            return Ok(credential.clone());
        }
        // Ask the provider for every scope requested earlier, not just `want`.
        let mut requested = state.requested.clone();
        for scope in &want {
            if !requested.contains(scope) {
                requested.push(scope.clone());
            }
        }
        let mut meta = inner.meta.clone();
        meta.set_scopes(requested.clone());
        let req = CredentialRequest::new(&inner.env, &inner.command_path, &inner.tier, &meta);
        let credential = inner
            .auth
            .get_credential_for(&inner.provider, &req)
            .await
            .map_err(|source| auth_resolution_error(&inner.provider, source))?;
        // A step-up must not silently switch to another principal.
        if let Some(previous) = &state.credential
            && !Self::same_principal(previous, &credential)
        {
            let previous_key = identity_key(previous);
            let new_key = identity_key(&credential);
            return Err(CliCoreError::message(format!(
                "scope step-up authenticated as a different identity \
                 (was {previous_key:?}, now {new_key:?}); aborting"
            )));
        }
        state.credential = Some(credential.clone());
        state.requested = requested;
        // Only the first credential is stored in the cell; later `set` calls
        // fail and their error is intentionally discarded.
        drop(inner.cell.set(credential.clone()));
        Ok(credential)
    }

    /// Like [`CredentialResolver::resolve`], but returns `Ok(None)` for
    /// `no_auth` commands instead of failing.
    ///
    /// # Errors
    /// Same as [`CredentialResolver::resolve`] for commands that use auth.
    pub async fn try_resolve(&self) -> Result<Option<Credential>> {
        if self.inner.no_auth {
            return Ok(None);
        }
        self.resolve().await.map(Some)
    }

    /// Returns the first credential this resolver obtained, without awaiting.
    ///
    /// `None` until a resolution has succeeded. After a scope step-up this
    /// still returns the credential obtained before the step-up (the step-up
    /// verified it belongs to the same principal); await
    /// [`CredentialResolver::resolve_with_scopes`] for the latest one.
    #[must_use]
    pub fn peek(&self) -> Option<&Credential> {
        self.inner.cell.get()
    }

    /// Decides whether two credentials belong to the same principal.
    ///
    /// Subjects are compared when both credentials carry one; otherwise
    /// identities are compared when both carry one. A subject on only one
    /// side is never compared against the other side's identity — those are
    /// different kinds of value, and comparing them would reject a step-up
    /// for the very same principal. With nothing comparable on both sides the
    /// check passes, as before.
    fn same_principal(previous: &Credential, current: &Credential) -> bool {
        if !previous.sub.is_empty() && !current.sub.is_empty() {
            return previous.sub == current.sub;
        }
        if !previous.identity.is_empty() && !current.identity.is_empty() {
            return previous.identity == current.identity;
        }
        true
    }
}
/// Normalizes a credential-provider failure into an auth error.
///
/// `MissingAuthProvider` and `AuthProvider` errors pass through unchanged;
/// any other error is wrapped in `CliCoreError::AuthProvider` tagged with
/// `provider`.
fn auth_resolution_error(provider: &str, source: CliCoreError) -> CliCoreError {
    let already_auth = matches!(
        source,
        CliCoreError::MissingAuthProvider(_) | CliCoreError::AuthProvider { .. }
    );
    if already_auth {
        source
    } else {
        CliCoreError::AuthProvider {
            provider: provider.to_owned(),
            source: Box::new(source),
        }
    }
}
/// Key naming who a credential belongs to: its `sub` when non-empty,
/// otherwise its `identity`.
fn identity_key(credential: &Credential) -> &str {
    Some(credential.sub.as_str())
        .filter(|sub| !sub.is_empty())
        .unwrap_or(credential.identity.as_str())
}
/// Authorization hook consulted by [`Middleware::run`] before a command runs.
#[async_trait]
pub trait Authorizer: Send + Sync + std::fmt::Debug {
    /// Decides whether `command_path` may run with the effective `args`.
    ///
    /// `credential` lets the authorizer resolve (or step up) the caller's
    /// credential on demand; `reason` is the middleware's `reason` setting
    /// and `tier` the command's declared tier (`CommandMeta::tier`).
    ///
    /// # Errors
    /// Any error denies the command. The middleware records errors whose
    /// `is_auth()` is true as `auth-error` and all others as `denied`.
    async fn authorize(
        &self,
        command_path: &str,
        args: &ValueMap,
        credential: &CredentialResolver,
        reason: &str,
        tier: Tier,
    ) -> Result<()>;
}
/// Sink for audit records; [`Middleware::run`] writes one per command outcome.
#[async_trait]
pub trait Auditor: Send + Sync + std::fmt::Debug {
    /// Appends one audit record.
    ///
    /// `result` is the middleware's outcome tag (`ok`, `error`, `denied`,
    /// `auth-error` or `dry-run`); `identity` is empty when no credential was
    /// resolved; `reason` is the middleware's `reason` setting.
    ///
    /// # Errors
    /// The middleware logs failures as warnings and otherwise ignores them;
    /// they never fail the command.
    async fn append(
        &self,
        command_path: &str,
        args: &ValueMap,
        identity: &str,
        result: &str,
        reason: &str,
    ) -> Result<()>;
}
/// Sink for structured [`ActivityEvent`]s, one per command outcome.
#[async_trait]
pub trait ActivityEmitter: Send + Sync + std::fmt::Debug {
    /// Publishes one event.
    ///
    /// # Errors
    /// The middleware logs failures as warnings and otherwise ignores them.
    async fn emit(&self, event: ActivityEvent) -> Result<()>;
}
/// Structured record of one command outcome, as built by the middleware for
/// its [`ActivityEmitter`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ActivityEvent {
    /// Emission time; the middleware writes RFC 3339 UTC with second precision.
    pub timestamp: String,
    /// Application id (`Middleware::app_id`).
    pub app: String,
    /// Command path.
    pub command: String,
    /// Middleware environment (`Middleware::env`).
    pub env: String,
    /// What the outcome is attributed to: the auth provider for auth
    /// failures, otherwise the command's system or path.
    pub backend: String,
    /// Identity of the resolved credential; empty when none was resolved.
    pub identity: String,
    /// Subject of the resolved credential; empty when none was resolved.
    pub sub: String,
    /// Account type of the resolved credential; empty when none was resolved.
    pub account_type: String,
    /// Outcome tag: `ok`, `error`, `denied`, `auth-error` or `dry-run`.
    pub status: String,
    /// Error message; empty for `ok` and `dry-run`.
    pub error: String,
    /// Middleware `reason` setting.
    pub reason: String,
    /// Effective command arguments.
    pub args: ValueMap,
    /// Milliseconds since the middleware started handling the command,
    /// saturating at `i64::MAX`.
    pub duration_ms: i64,
    /// Extra metadata; the middleware currently always sends an empty map.
    pub meta: ValueMap,
}
/// Cross-cutting command pipeline: authorization, credential resolution,
/// auditing, activity emission, schema/dry-run handling and output rendering.
/// See [`Middleware::run`] for how the fields are applied.
#[derive(Clone, Debug, Default)]
pub struct Middleware {
    /// Optional authorization hook consulted before every command.
    pub authz: Option<Arc<dyn Authorizer>>,
    /// Dispatcher used to fetch credentials from auth providers.
    pub auth: Dispatcher,
    /// Optional audit sink; write failures are logged and ignored.
    pub auditor: Option<Arc<dyn Auditor>>,
    /// Optional activity sink; emit failures are logged and ignored.
    pub activity: Option<Arc<dyn ActivityEmitter>>,
    /// Application id: recorded in activity events and used as the envelope
    /// system for schema output and invalid-output-format errors.
    pub app_id: String,
    /// Provider used when a command's metadata names none (or an empty one).
    pub default_auth_provider: String,
    /// Output format name, validated and parsed into [`OutputFormat`] at
    /// render time.
    pub output_format: String,
    /// Active environment: injected as `args["env"]` for authenticated
    /// commands, used for credential requests unless the command has a
    /// `fixed_env`, and recorded in envelope context and activity events.
    pub env: String,
    /// Verbosity setting passed to `Envelope::prepare_for_render`.
    pub verbose: String,
    /// Dry-run mode; only affects commands whose metadata sets
    /// `dry_run_prompt`.
    pub dry_run: bool,
    /// Field projection; overrides a command's default fields when non-empty.
    pub fields: String,
    /// Filter passed to the output pipeline.
    pub filter: String,
    /// Expression passed to the output pipeline.
    pub expr: String,
    /// Pagination limit passed to the output pipeline.
    pub limit: i64,
    /// Pagination offset passed to the output pipeline.
    pub offset: i64,
    /// Reason text forwarded to the authorizer, auditor and activity events.
    pub reason: String,
    /// When `true`, the command's schema is rendered instead of running it.
    pub schema: bool,
    /// Not read by the middleware code here; presumably enforced by the
    /// caller — TODO confirm.
    pub timeout: Option<Duration>,
    /// Not read by the middleware code here — TODO confirm its consumer.
    pub debug: String,
    /// Not read by the middleware code here — TODO confirm its consumer.
    pub search: String,
    /// Schemas looked up by command path when `schema` is set.
    pub schema_registry: SchemaRegistry,
    /// Registry of human-readable views used for `OutputFormat::Human`.
    pub human_views: HumanViewRegistry,
    /// Not read by the middleware code here — TODO confirm its consumer.
    pub config: Arc<crate::config::ConfigFile>,
    /// Optional environments registry; not read by the middleware code here.
    pub environments: Option<Arc<crate::environments::Environments>>,
}
/// Result of [`Middleware::run`]: the final envelope plus its rendered text.
#[derive(Clone, Debug, PartialEq)]
pub struct MiddlewareOutput {
    /// Envelope after context was attached and `prepare_for_render` ran.
    pub envelope: Envelope,
    /// `envelope` rendered in the configured output format.
    pub rendered: String,
    /// Exit code: `0` on success, otherwise from `exit_code_for_error`.
    pub exit_code: i32,
}
/// Per-invocation input to [`Middleware::run`].
#[derive(Clone, Debug, PartialEq)]
pub struct MiddlewareRequest<'request> {
    /// Static metadata of the command being run.
    pub meta: CommandMeta,
    /// Full command path; the part before the first `:` is the fallback system.
    pub command_path: &'request str,
    /// System the command belongs to; when empty it is derived from
    /// `command_path`.
    pub system: &'request str,
    /// Arguments as supplied by the caller; echoed into envelope context.
    pub user_args: ValueMap,
    /// Effective arguments seen by the authorizer, auditor and activity
    /// events; `env` may be injected.
    pub args: ValueMap,
    /// Fields to project when the middleware's `fields` option is empty.
    pub default_fields: &'request str,
    /// Human view to render with, if any.
    pub view_id: Option<&'request str>,
    /// How the command interacts with credential resolution.
    pub auth: AuthRequirement,
}
impl Middleware {
    /// Creates a middleware with every option at its default value
    /// (equivalent to `Middleware::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Runs `command` through the middleware pipeline and renders its result.
    ///
    /// Steps, in order:
    /// 1. For commands that are not [`AuthRequirement::None`], inject
    ///    `self.env` as `args["env"]` unless already set, then build the
    ///    [`CredentialResolver`] shared by the authorizer and the command.
    /// 2. For [`AuthRequirement::None`] commands, answer a schema request
    ///    immediately (before the authorizer runs).
    /// 3. Run the configured [`Authorizer`], if any; a failure is audited,
    ///    emitted as activity and rendered as an error envelope.
    /// 4. Answer a schema request for the remaining commands.
    /// 5. In dry-run mode, for commands with `dry_run_prompt`, return a
    ///    dry-run envelope without executing the command.
    /// 6. For [`AuthRequirement::Required`] commands, resolve the credential
    ///    before the command runs; failure is reported as `auth-error`.
    /// 7. Execute `command`, audit and emit activity for the outcome, and
    ///    render the success envelope or an error envelope.
    ///
    /// Authorization, credential and command failures are returned as `Ok`
    /// with an error envelope whose `exit_code` comes from
    /// `exit_code_for_error`.
    ///
    /// # Errors
    /// Returns `Err` when building or rendering an output envelope fails
    /// (for example the data pipeline or the output renderer errors).
    pub async fn run<F, Fut, Output>(
        &self,
        request: MiddlewareRequest<'_>,
        command: F,
    ) -> Result<MiddlewareOutput>
    where
        F: FnOnce(CredentialResolver) -> Fut + Send,
        Fut: Future<Output = Result<Output>> + Send,
        Output: Into<CommandResult>,
    {
        let start = Instant::now();
        let MiddlewareRequest {
            meta,
            command_path,
            system,
            user_args,
            mut args,
            default_fields,
            view_id,
            auth,
        } = request;
        let no_auth = auth.is_none();
        let command_system = effective_request_system(system, command_path);
        // Authenticated commands see the active environment in their effective
        // args unless the caller set one explicitly.
        // NOTE(review): this uses `self.env` even when `meta.fixed_env()`
        // overrides the credential environment below — confirm intended.
        if !no_auth && !self.env.is_empty() && !args.contains_key("env") {
            args.insert("env".to_owned(), Value::String(self.env.clone()));
        }
        let provider_name = meta
            .provider()
            .filter(|provider| !provider.is_empty())
            .unwrap_or(&self.default_auth_provider)
            .to_owned();
        let resolved_env = meta.fixed_env().unwrap_or(&self.env).to_owned();
        // Raw tier text (possibly empty) forwarded to the credential provider.
        let tier_text = meta
            .auth_metadata
            .get("tier")
            .map_or("", String::as_str)
            .to_owned();
        let resolver = CredentialResolver::new(
            self.auth.clone(),
            provider_name.clone(),
            resolved_env,
            command_path.to_owned(),
            tier_text,
            no_auth,
            meta.clone(),
        );
        // Auth-free commands can describe their schema without authorization.
        if no_auth
            && let Some(output) =
                self.render_schema_if_requested(command_path, start, &user_args, &args, "")?
        {
            return Ok(output);
        }
        if let Some(authz) = &self.authz
            && let Err(err) = authz
                .authorize(command_path, &args, &resolver, &self.reason, meta.tier())
                .await
        {
            // The authorizer may already have resolved a credential.
            let identity = resolver.peek().map_or("", |cred| cred.identity.as_str());
            let had_auth_error = err.is_auth();
            let result_tag = if had_auth_error {
                "auth-error"
            } else {
                "denied"
            };
            // Auth failures are attributed to the provider, denials to the command.
            let backend = if had_auth_error {
                provider_name.as_str()
            } else {
                command_path
            };
            self.write_audit(command_path, &args, identity, result_tag)
                .await;
            self.emit_activity(
                command_path,
                &args,
                resolver.peek(),
                result_tag,
                backend,
                &err.to_string(),
                start,
            )
            .await;
            return self.render_error(&err, command_path, start, &user_args, &args, identity);
        }
        let schema_identity = resolver.peek().map_or("", |cred| cred.identity.as_str());
        if let Some(output) = self.render_schema_if_requested(
            command_path,
            start,
            &user_args,
            &args,
            schema_identity,
        )? {
            return Ok(output);
        }
        if self.dry_run && meta.dry_run_prompt {
            let identity = resolver.peek().map_or("", |cred| cred.identity.as_str());
            self.write_audit(command_path, &args, identity, "dry-run")
                .await;
            self.emit_activity(
                command_path,
                &args,
                resolver.peek(),
                "dry-run",
                command_path,
                "",
                start,
            )
            .await;
            let envelope = Envelope::success(
                json!({
                    "command": command_path,
                    "action": "dry-run: would execute",
                }),
                command_path,
            )
            .with_dry_run();
            return self.render_envelope(
                envelope,
                "",
                "",
                command_path,
                start,
                &user_args,
                &args,
                identity,
            );
        }
        if auth.is_required()
            && let Err(err) = resolver.resolve().await
        {
            // Resolution failed, so there is no identity to record.
            self.write_audit(command_path, &args, "", "auth-error")
                .await;
            self.emit_activity(
                command_path,
                &args,
                resolver.peek(),
                "auth-error",
                provider_name.as_str(),
                &err.to_string(),
                start,
            )
            .await;
            return self.render_error(&err, command_path, start, &user_args, &args, "");
        }
        let result = match command(resolver.clone()).await {
            Ok(result) => result.into(),
            Err(err) => {
                let identity = resolver.peek().map_or("", |cred| cred.identity.as_str());
                // Auth errors keep the command path as the envelope system and
                // blame the provider; other errors report the system they came
                // from, falling back to the command's own system.
                let (result_tag, error_system, activity_backend) = if err.is_auth() {
                    ("auth-error", command_path, provider_name.as_str())
                } else {
                    let system = err.system().unwrap_or(&command_system);
                    ("error", system, system)
                };
                self.write_audit(command_path, &args, identity, result_tag)
                    .await;
                self.emit_activity(
                    command_path,
                    &args,
                    resolver.peek(),
                    result_tag,
                    activity_backend,
                    &err.to_string(),
                    start,
                )
                .await;
                return self.render_error(&err, error_system, start, &user_args, &args, identity);
            }
        };
        let identity = resolver.peek().map_or("", |cred| cred.identity.as_str());
        self.write_audit(command_path, &args, identity, "ok").await;
        self.emit_activity(
            command_path,
            &args,
            resolver.peek(),
            "ok",
            &command_system,
            "",
            start,
        )
        .await;
        let CommandResult { data, metadata } = result;
        self.render_envelope(
            Envelope::success(data, command_system).with_next_actions(metadata.next_actions),
            default_fields,
            view_id.unwrap_or_default(),
            command_path,
            start,
            &user_args,
            &args,
            identity,
        )
    }

    /// Convenience wrapper around [`Middleware::run`] for commands that never
    /// use a credential; the system is derived from `command_path` and no
    /// human view is selected.
    ///
    /// # Errors
    /// Same as [`Middleware::run`].
    #[doc(hidden)]
    pub async fn run_no_auth<F, Fut>(
        &self,
        meta: CommandMeta,
        command_path: &str,
        user_args: ValueMap,
        args: ValueMap,
        default_fields: &str,
        command: F,
    ) -> Result<MiddlewareOutput>
    where
        F: FnOnce() -> Fut + Send,
        Fut: Future<Output = Result<CommandResult>> + Send,
    {
        self.run(
            MiddlewareRequest {
                meta,
                command_path,
                system: fallback_system(command_path),
                user_args,
                args,
                default_fields,
                view_id: None,
                auth: AuthRequirement::None,
            },
            async move |_resolver| command().await,
        )
        .await
    }

    /// Best-effort audit write: a no-op without an auditor; failures are
    /// logged as warnings and never fail the command.
    async fn write_audit(&self, command_path: &str, args: &ValueMap, identity: &str, result: &str) {
        if let Some(auditor) = &self.auditor
            && let Err(err) = auditor
                .append(command_path, args, identity, result, &self.reason)
                .await
        {
            tracing::warn!(command = command_path, error = %err, "audit log write failed");
        }
    }

    /// Best-effort activity emission: a no-op without an emitter; failures
    /// are logged as warnings. Credential fields are empty when `credential`
    /// is `None`.
    // Allowed: the argument list mirrors the event fields populated below.
    #[allow(clippy::too_many_arguments)]
    async fn emit_activity(
        &self,
        command_path: &str,
        args: &ValueMap,
        credential: Option<&Credential>,
        result: &str,
        backend: &str,
        error: &str,
        start: Instant,
    ) {
        let Some(activity) = &self.activity else {
            return;
        };
        let (identity, sub, account_type) = credential.map_or_else(
            || (String::new(), String::new(), String::new()),
            |credential| {
                (
                    credential.identity.clone(),
                    credential.sub.clone(),
                    credential.account_type.clone(),
                )
            },
        );
        // Saturate rather than wrap if the elapsed milliseconds exceed i64.
        let duration_ms = i64::try_from(start.elapsed().as_millis()).unwrap_or(i64::MAX);
        // NOTE(review): `env` is `self.env`, not the command's `fixed_env`
        // override — confirm intended.
        let event = ActivityEvent {
            timestamp: chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            app: self.app_id.clone(),
            command: command_path.to_owned(),
            env: self.env.clone(),
            backend: backend.to_owned(),
            identity,
            sub,
            account_type,
            status: result.to_owned(),
            error: error.to_owned(),
            reason: self.reason.clone(),
            args: args.clone(),
            duration_ms,
            meta: ValueMap::new(),
        };
        if let Err(err) = activity.emit(event).await {
            tracing::warn!(command = command_path, error = %err, "activity emit failed");
        }
    }

    /// Renders the command's schema when `self.schema` is set, or a
    /// "no schema" response when none is registered for `command_path`.
    /// Returns `Ok(None)` when no schema output was requested.
    ///
    /// # Errors
    /// Propagates rendering failures from `render_envelope`.
    fn render_schema_if_requested(
        &self,
        command_path: &str,
        start: Instant,
        user_args: &ValueMap,
        effective_args: &ValueMap,
        identity: &str,
    ) -> Result<Option<MiddlewareOutput>> {
        if self.schema {
            let envelope = match self.schema_registry.get_by_path(command_path) {
                Some(schema) => Envelope::success(schema, self.app_id.clone()),
                None => Envelope::success(
                    crate::output::no_schema_response(command_path),
                    self.app_id.clone(),
                ),
            };
            return self
                .render_envelope(
                    envelope,
                    "",
                    "",
                    command_path,
                    start,
                    user_args,
                    effective_args,
                    identity,
                )
                .map(Some);
        }
        Ok(None)
    }

    /// Post-processes and renders a success envelope.
    ///
    /// Applies the data pipeline (filter, expr, limit, offset, field
    /// projection), attaches request context, and renders in
    /// `self.output_format`. `self.fields` overrides `default_fields` when
    /// non-empty. An unrecognized output format is rendered as an error
    /// envelope attributed to `app_id` instead.
    ///
    /// # Errors
    /// Propagates failures from output-format parsing, the pipeline, or the
    /// renderer.
    // Allowed: every argument feeds the pipeline or the envelope context.
    #[allow(clippy::too_many_arguments)]
    fn render_envelope(
        &self,
        mut envelope: Envelope,
        default_fields: &str,
        view_id: &str,
        command_path: &str,
        start: Instant,
        user_args: &ValueMap,
        effective_args: &ValueMap,
        identity: &str,
    ) -> Result<MiddlewareOutput> {
        if !is_valid_output_format(&self.output_format) {
            let err = CliCoreError::InvalidOutputFormat(self.output_format.clone());
            return self.render_error(
                &err,
                &self.app_id,
                start,
                user_args,
                effective_args,
                identity,
            );
        }
        let output_format = self.output_format.parse::<OutputFormat>()?;
        let effective_fields = if self.fields.is_empty() {
            default_fields
        } else {
            self.fields.as_str()
        };
        let human_view = output_format == OutputFormat::Human && self.human_views.has_view(view_id);
        // A registered human view receives `effective_fields` itself, so the
        // pipeline is given no field list in that case.
        let projection_fields = if human_view { "" } else { effective_fields };
        if let Some(data) = &mut envelope.data {
            let pagination = apply_pipeline(
                data,
                &PipelineOpts {
                    filter: self.filter.clone(),
                    limit: self.limit,
                    offset: self.offset,
                    expr: self.expr.clone(),
                    fields: projection_fields.to_owned(),
                },
            )?;
            // Pagination is only recorded when the envelope carries metadata.
            if let Some(pagination) = pagination
                && let Some(metadata) = &mut envelope.metadata
            {
                metadata.pagination = Some(pagination);
            }
        }
        envelope.with_context(
            command_path,
            &self.env,
            identity,
            start.elapsed(),
            Some(Value::Object(user_args.clone())),
            Some(Value::Object(effective_args.clone())),
        );
        let prepared = envelope.prepare_for_render(&self.verbose);
        let rendered = if output_format == OutputFormat::Human {
            render_human_with_registry_selected(
                &prepared,
                &self.human_views,
                view_id,
                effective_fields,
            )
        } else {
            crate::output::render(output_format, &prepared)?
        };
        Ok(MiddlewareOutput {
            envelope: prepared,
            rendered,
            exit_code: 0,
        })
    }

    /// Builds and renders an error envelope for `err`, attributed to `system`;
    /// the exit code comes from `exit_code_for_error`.
    ///
    /// # Errors
    /// Propagates failures from rendering in `self.output_format`.
    fn render_error(
        &self,
        err: &(dyn std::error::Error + 'static),
        system: &str,
        start: Instant,
        user_args: &ValueMap,
        effective_args: &ValueMap,
        identity: &str,
    ) -> Result<MiddlewareOutput> {
        let mut envelope = build_error_envelope(err, system);
        // NOTE(review): the command path in the error context is left blank,
        // unlike `render_envelope` — confirm intended.
        envelope.with_context(
            "",
            &self.env,
            identity,
            start.elapsed(),
            Some(Value::Object(user_args.clone())),
            Some(Value::Object(effective_args.clone())),
        );
        let prepared = envelope.prepare_for_render(&self.verbose);
        let rendered = crate::output::render_format(&self.output_format, &prepared)?;
        Ok(MiddlewareOutput {
            envelope: prepared,
            rendered,
            exit_code: exit_code_for_error(err),
        })
    }
}
#[must_use]
pub fn value_map(entries: impl IntoIterator<Item = (impl Into<String>, Value)>) -> ValueMap {
entries
.into_iter()
.map(|(key, value)| (key.into(), value))
.collect()
}
/// Returns `system` when it is non-empty; otherwise the system derived from
/// `command_path` by `fallback_system`.
fn effective_request_system(system: &str, command_path: &str) -> String {
    let chosen = if system.is_empty() {
        fallback_system(command_path)
    } else {
        system
    };
    chosen.to_owned()
}
/// Derives a system name from a command path: everything before the first
/// `:`, or the whole path when it contains no `:`.
fn fallback_system(command_path: &str) -> &str {
    // `split` always yields at least one piece, so the fallback is never hit.
    command_path.split(':').next().unwrap_or(command_path)
}
impl From<CliCoreError> for Value {
fn from(error: CliCoreError) -> Self {
Value::String(error.to_string())
}
}
#[cfg(test)]
mod env_wire_tests {
    use super::*;

    /// A fresh `Middleware` has no environments registry, and an attached one
    /// is exposed unchanged through the public field.
    #[test]
    fn middleware_carries_optional_environments() {
        use std::sync::Arc;

        let mut middleware = Middleware::new();
        assert!(middleware.environments.is_none());

        let registry = crate::environments::Environments::new("prod");
        middleware.environments = Some(Arc::new(registry));

        let default_env = middleware
            .environments
            .as_ref()
            .map(|envs| envs.default_env().to_owned());
        assert_eq!(default_env, Some("prod".to_owned()));
    }
}