use crate::evaluation::evaluate_guard;
use crate::types::{
ClientOptions, GetGuardsRequest, Guard, GuardExecutionPerformance, InitError, Options,
Properties, ProtectedContext, SendUnadoptedGuardsRequest, Signal, SignalPerformance,
TraceContext,
};
use opentelemetry::trace::TraceContextExt;
use reqwest::Client as HttpClient;
use serde::Deserialize;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::panic::Location;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::runtime::Handle;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use tokio::time;
thread_local! {
// Per-thread execution context. It starts as `None`; `Client::with_execution_inner`
// installs a value for the duration of a closure and clears it afterwards.
static EXECUTION_STATE: RefCell<Option<ExecutionState>> = const { RefCell::new(None) };
}
// Process-wide counter; `next_signal_id` bumps it so each generated id gets a distinct suffix.
static SIGNAL_COUNTER: AtomicU64 = AtomicU64::new(0);
// Key under which the bundle without a protected context is stored in `Inner::bundles`.
const PUBLIC_BUNDLE_KEY: &str = "";
/// Per-thread bookkeeping that ties together the signals emitted during one execution.
#[derive(Clone)]
struct ExecutionState {
/// Identifier copied onto every signal recorded while this state is installed.
execution_id: String,
/// Number of signals issued so far; incremented before each new signal is described.
sequence_number: i64,
/// Id of the most recently issued signal; default parent for the next one.
last_signal_id: Option<String>,
}
/// Identity and ordering fields attached to a single signal before it is buffered.
struct SignalMetadata {
/// Freshly generated id of this signal.
signal_id: String,
/// Id of the execution the signal belongs to.
execution_id: String,
/// Id of the signal this one is chained to, if any.
parent_signal_id: Option<String>,
/// Position of the signal within its execution (starts at 1).
sequence_number: i64,
}
/// Outcome of `Client::evaluate_guard_in_scope`.
struct GuardEvaluation {
/// Whether the guard is considered open.
result: bool,
/// The guard definition that was found, if the bundle was ready and contained it.
guard: Option<Guard>,
/// Merged scope + per-call properties; only set when an adopted guard was evaluated.
props: Option<Properties>,
/// The `guard_check` signal that was buffered, when signal emission was requested.
signal: Option<Signal>,
}
/// Everything `Client::buffer_signal_inner` needs to build and enqueue one `Signal`.
struct BufferedSignalInput {
/// Name of the guard the signal is about.
guard_name: String,
/// Evaluation result reported on the signal.
result: bool,
/// Properties the guard was evaluated with.
props: Properties,
/// Ids and sequence number for the signal.
metadata: SignalMetadata,
/// Call site, formatted as `file:line`.
callsite_id: String,
/// Signal kind; this file emits `"guard_check"` and `"guard_execution"`.
kind: String,
/// Optional performance measurement attached to the signal.
measurement: Option<SignalPerformance>,
}
/// A cached set of guards fetched for one (optional) protected context.
#[derive(Clone)]
struct GuardBundle {
/// Key of this bundle inside `Inner::bundles`.
key: String,
/// Guards indexed by name.
guards: HashMap<String, Guard>,
/// False until guards have been stored; evaluation returns the fallback while false.
ready: bool,
/// ETag from the last successful fetch; sent back as `If-None-Match` when non-empty.
etag: String,
/// Protected context sent with fetch requests for this bundle, if any.
protected_context: Option<ProtectedContext>,
/// Refresh interval that applies to this bundle, in seconds.
refresh_rate_seconds: u64,
}
/// Body of a successful guards fetch (camelCase on the wire).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GuardsResponse {
/// Guards returned by the backend.
guards: Vec<Guard>,
/// Refresh interval from the backend; defaults to 0 when absent, which the
/// fetch code treats as "use the configured rate".
#[serde(default)]
refresh_rate_seconds: u64,
/// ETag of this response; defaults to empty when absent.
#[serde(default)]
etag: String,
}
/// Counter for one rate-limit bucket within the current 60-second window.
#[derive(Clone)]
struct RateLimitEntry {
/// Instant at which the current window began.
window_start: Instant,
/// Number of checks admitted in the current window.
count: i32,
}
/// State shared between a `Client`, its `Scope`s and the background tasks.
struct Inner {
/// Project client key; used as bearer token and echoed in request bodies.
project_client_key_id: String,
/// Client options after `with_defaults` has been applied.
opts: ClientOptions,
/// Runtime handle captured at init; used to spawn flushes from synchronous code.
runtime: Handle,
/// HTTP client used for all backend calls.
http: HttpClient,
/// Guard bundles keyed by bundle key (`PUBLIC_BUNDLE_KEY` for the public one).
bundles: RwLock<HashMap<String, GuardBundle>>,
/// Interval the refresh loop currently sleeps for, in seconds.
current_refresh_rate_seconds: Mutex<u64>,
/// Wakes the refresh loop when the interval above changes.
refresh_notify: Notify,
/// Signals waiting to be flushed.
signal_buffer: Mutex<Vec<Signal>>,
/// Signals dropped since the last signal was built; reported on the next signal.
dropped_signals_pending: Mutex<i32>,
/// Unadopted guard names waiting to be sent.
pending_unadopted_guards: Mutex<HashSet<String>>,
/// Unadopted guard names already queued once; prevents re-queuing.
reported_unadopted_guards: Mutex<HashSet<String>>,
/// Rate-limit counters keyed by bucket key.
rate_limit_state: Mutex<HashMap<String, RateLimitEntry>>,
}
impl Inner {
    /// Prints `message` to stderr with a `[liteguard]` prefix, unless the
    /// client was configured with `quiet`.
    fn log(&self, message: impl AsRef<str>) {
        if !self.opts.quiet {
            eprintln!("[liteguard] {}", message.as_ref());
        }
    }
}
/// An immutable evaluation context: a set of properties plus the guard bundle to
/// evaluate against. Every "mutating" method returns a new `Scope`.
#[derive(Clone)]
pub struct Scope {
/// Shared client state.
inner: Arc<Inner>,
/// Properties applied to every evaluation made through this scope.
properties: Properties,
/// Key of the bundle in `Inner::bundles` this scope evaluates against.
bundle_key: String,
/// Protected context bound to this scope, if any.
protected_context: Option<ProtectedContext>,
}
impl Scope {
/// Returns the properties carried by this scope.
pub fn properties(&self) -> &Properties {
&self.properties
}
/// Returns the protected context bound to this scope, or `None` for a public scope.
pub fn protected_context(&self) -> Option<&ProtectedContext> {
self.protected_context.as_ref()
}
/// Returns a new scope whose properties are this scope's properties overlaid
/// with `properties`; entries in `properties` win on key collisions. The bundle
/// and protected context are carried over unchanged.
pub fn with_properties(&self, properties: Properties) -> Scope {
    let combined = properties
        .0
        .into_iter()
        .fold(self.properties.clone(), |mut acc, (key, value)| {
            acc.insert(key, value);
            acc
        });
    Scope {
        inner: Arc::clone(&self.inner),
        properties: combined,
        bundle_key: self.bundle_key.clone(),
        protected_context: self.protected_context.clone(),
    }
}
/// Alias for [`Scope::with_properties`].
pub fn add_properties(&self, properties: Properties) -> Scope {
self.with_properties(properties)
}
/// Returns a new scope with the listed property names removed. The bundle and
/// protected context are carried over unchanged.
pub fn clear_properties(&self, names: &[&str]) -> Scope {
    let mut remaining = self.properties.clone();
    names.iter().for_each(|name| {
        remaining.remove(name);
    });
    Scope {
        inner: Arc::clone(&self.inner),
        properties: remaining,
        bundle_key: self.bundle_key.clone(),
        protected_context: self.protected_context.clone(),
    }
}
/// Returns a new scope with no properties, keeping the bundle and protected context.
pub fn reset_properties(&self) -> Scope {
Scope {
inner: Arc::clone(&self.inner),
properties: Properties::new(),
bundle_key: self.bundle_key.clone(),
protected_context: self.protected_context.clone(),
}
}
/// Returns true when this scope shares its internal state with `client`
/// (pointer identity of the shared `Arc`).
pub fn belongs_to(&self, client: &Client) -> bool {
Arc::ptr_eq(&self.inner, &client.inner)
}
/// Returns a new scope bound to `protected_context`, keeping this scope's properties.
///
/// Makes sure a bundle exists for the context, fetching its guards if the
/// bundle is not ready yet.
///
/// # Errors
/// Propagates the `reqwest::Error` returned by the guard fetch.
pub async fn bind_protected_context(
&self,
protected_context: ProtectedContext,
) -> Result<Scope, reqwest::Error> {
let bundle_key =
Client::ensure_bundle_for_protected_context(&self.inner, protected_context.clone())
.await?;
Ok(Scope {
inner: Arc::clone(&self.inner),
properties: self.properties.clone(),
bundle_key,
protected_context: Some(protected_context),
})
}
/// Returns a new scope that evaluates against the public bundle, keeping this
/// scope's properties. Fetches the public bundle first if it is not ready.
///
/// # Errors
/// Propagates the `reqwest::Error` returned by the guard fetch.
pub async fn clear_protected_context(&self) -> Result<Scope, reqwest::Error> {
Client::ensure_public_bundle_ready(&self.inner).await?;
Ok(Scope {
inner: Arc::clone(&self.inner),
properties: self.properties.clone(),
bundle_key: PUBLIC_BUNDLE_KEY.to_owned(),
protected_context: None,
})
}
/// Evaluates guard `name` with default options and records a `guard_check` signal.
#[track_caller]
pub fn is_open(&self, name: &str) -> bool {
self.is_open_with_options(name, &Options::default())
}
/// Evaluates guard `name` in this scope with per-call `options`.
///
/// Signal emission is enabled, so for an adopted guard a `guard_check` signal
/// is buffered and any rate limit is consumed. The caller's location is used
/// as the signal's callsite id.
#[track_caller]
pub fn is_open_with_options(&self, name: &str, options: &Options) -> bool {
Client::evaluate_guard_in_scope(&self.inner, self, name, options, true, Location::caller())
.result
}
/// Evaluates guard `name` with default options without recording a signal.
#[track_caller]
pub fn peek_is_open(&self, name: &str) -> bool {
self.peek_is_open_with_options(name, &Options::default())
}
/// Evaluates guard `name` in this scope with per-call `options`, with signal
/// emission disabled: no signal is buffered and rate-limit counters are only
/// inspected, not incremented.
#[track_caller]
pub fn peek_is_open_with_options(&self, name: &str, options: &Options) -> bool {
Client::evaluate_guard_in_scope(&self.inner, self, name, options, false, Location::caller())
.result
}
/// Runs `f` and returns `Some(result)` when guard `name` is open under default
/// options; returns `None` without running `f` otherwise.
#[track_caller]
pub fn execute_if_open<F, R>(&self, name: &str, f: F) -> Option<R>
where
F: FnOnce() -> R,
{
self.execute_if_open_with_options(name, &Options::default(), f)
}
#[track_caller]
pub fn execute_if_open_with_options<F, R>(
&self,
name: &str,
options: &Options,
f: F,
) -> Option<R>
where
F: FnOnce() -> R,
{
Client::with_execution_inner(&self.inner, || {
let evaluation = Client::evaluate_guard_in_scope(
&self.inner,
self,
name,
options,
true,
Location::caller(),
);
if !evaluation.result {
return None;
}
let guard_check_signal = evaluation.signal;
if guard_check_signal.is_none() {
return Some(f());
}
let guard = evaluation
.guard
.as_ref()
.expect("guard should exist when signal exists");
let props = evaluation
.props
.expect("props should exist when signal exists");
let measurement_enabled = Client::is_measurement_enabled(&self.inner, guard, options);
let started_at = Instant::now();
let guard_check_signal_id = guard_check_signal
.as_ref()
.map(|signal| signal.signal_id.clone())
.expect("signal id should exist");
let callsite_id = format!(
"{}:{}",
Location::caller().file(),
Location::caller().line()
);
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(value) => {
Client::buffer_signal_inner(
&self.inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result: true,
props,
metadata: Client::next_signal_metadata(Some(guard_check_signal_id)),
callsite_id,
kind: "guard_execution".into(),
measurement: if measurement_enabled {
Some(Client::capture_guard_execution_measurement(
started_at, true, None,
))
} else {
None
},
},
);
Some(value)
}
Err(payload) => {
Client::buffer_signal_inner(
&self.inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result: true,
props,
metadata: Client::next_signal_metadata(Some(guard_check_signal_id)),
callsite_id,
kind: "guard_execution".into(),
measurement: if measurement_enabled {
Some(Client::capture_guard_execution_measurement(
started_at,
false,
Some(Client::panic_error_class(payload.as_ref())),
))
} else {
None
},
},
);
std::panic::resume_unwind(payload)
}
}
})
}
/// Fallible variant of `execute_if_open`: runs `f` when guard `name` is open
/// under default options and returns `Ok(Some(value))`, `Ok(None)` when the
/// guard is closed, or the error returned by `f`.
#[track_caller]
pub fn try_execute_if_open<F, R, E>(&self, name: &str, f: F) -> Result<Option<R>, E>
where
F: FnOnce() -> Result<R, E>,
{
self.try_execute_if_open_with_options(name, &Options::default(), f)
}
#[track_caller]
pub fn try_execute_if_open_with_options<F, R, E>(
&self,
name: &str,
options: &Options,
f: F,
) -> Result<Option<R>, E>
where
F: FnOnce() -> Result<R, E>,
{
Client::with_execution_inner(&self.inner, || {
let evaluation = Client::evaluate_guard_in_scope(
&self.inner,
self,
name,
options,
true,
Location::caller(),
);
if !evaluation.result {
return Ok(None);
}
let guard_check_signal = evaluation.signal;
if guard_check_signal.is_none() {
return f().map(Some);
}
let guard = evaluation
.guard
.as_ref()
.expect("guard should exist when signal exists");
let props = evaluation
.props
.expect("props should exist when signal exists");
let measurement_enabled = Client::is_measurement_enabled(&self.inner, guard, options);
let started_at = Instant::now();
let guard_check_signal_id = guard_check_signal
.as_ref()
.map(|signal| signal.signal_id.clone())
.expect("signal id should exist");
let callsite_id = format!(
"{}:{}",
Location::caller().file(),
Location::caller().line()
);
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(Ok(value)) => {
Client::buffer_signal_inner(
&self.inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result: true,
props,
metadata: Client::next_signal_metadata(Some(guard_check_signal_id)),
callsite_id,
kind: "guard_execution".into(),
measurement: if measurement_enabled {
Some(Client::capture_guard_execution_measurement(
started_at, true, None,
))
} else {
None
},
},
);
Ok(Some(value))
}
Ok(Err(error)) => {
Client::buffer_signal_inner(
&self.inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result: true,
props,
metadata: Client::next_signal_metadata(Some(guard_check_signal_id)),
callsite_id,
kind: "guard_execution".into(),
measurement: if measurement_enabled {
Some(Client::capture_guard_execution_measurement(
started_at,
false,
Some(std::any::type_name::<E>().to_owned()),
))
} else {
None
},
},
);
Err(error)
}
Err(payload) => {
Client::buffer_signal_inner(
&self.inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result: true,
props,
metadata: Client::next_signal_metadata(Some(guard_check_signal_id)),
callsite_id,
kind: "guard_execution".into(),
measurement: if measurement_enabled {
Some(Client::capture_guard_execution_measurement(
started_at,
false,
Some(Client::panic_error_class(payload.as_ref())),
))
} else {
None
},
},
);
std::panic::resume_unwind(payload)
}
}
})
}
}
/// Handle to the guard client: owns the shared state and the two background tasks.
pub struct Client {
/// State shared with scopes and background tasks.
inner: Arc<Inner>,
/// Task that periodically re-fetches every known bundle.
refresh_task: JoinHandle<()>,
/// Task that periodically flushes buffered signals.
flush_task: JoinHandle<()>,
/// Set once the background tasks have been aborted, so they are aborted only once.
is_shutdown: AtomicBool,
}
impl Client {
/// Creates a client, performs an initial fetch of the public bundle and starts
/// the background refresh and flush tasks on the current Tokio runtime.
///
/// # Errors
/// Returns `InitError` when the HTTP client cannot be built. A failure of the
/// initial guard fetch is ignored here; the public bundle then simply stays
/// not-ready until a later fetch succeeds.
pub async fn init(
project_client_key_id: impl Into<String>,
opts: ClientOptions,
) -> Result<Client, InitError> {
let project_client_key_id = project_client_key_id.into();
let opts = opts.with_defaults();
let configured_refresh_rate_seconds = opts.refresh_rate_seconds;
let http = HttpClient::builder()
.timeout(Duration::from_secs(opts.http_timeout_seconds))
.build()
.map_err(|e| InitError(format!("failed to build HTTP client: {e}")))?;
// Seed the map with an empty, not-ready public bundle so lookups always find one.
let mut bundles = HashMap::new();
bundles.insert(
PUBLIC_BUNDLE_KEY.to_owned(),
Self::create_empty_bundle(
PUBLIC_BUNDLE_KEY.to_owned(),
None,
configured_refresh_rate_seconds,
),
);
let inner = Arc::new(Inner {
project_client_key_id,
opts,
runtime: Handle::current(),
http,
bundles: RwLock::new(bundles),
current_refresh_rate_seconds: Mutex::new(configured_refresh_rate_seconds),
refresh_notify: Notify::new(),
signal_buffer: Mutex::new(Vec::new()),
dropped_signals_pending: Mutex::new(0),
pending_unadopted_guards: Mutex::new(HashSet::new()),
reported_unadopted_guards: Mutex::new(HashSet::new()),
rate_limit_state: Mutex::new(HashMap::new()),
});
// Best-effort initial fetch; errors are deliberately discarded.
let _ = Self::fetch_guards_for_bundle(&inner, PUBLIC_BUNDLE_KEY).await;
let refresh_task = {
let inner = Arc::clone(&inner);
tokio::spawn(async move {
loop {
let interval = *inner.current_refresh_rate_seconds.lock().unwrap();
// A notification means the interval changed: restart the wait with the
// new value instead of refreshing now.
if time::timeout(
Duration::from_secs(interval),
inner.refresh_notify.notified(),
)
.await
.is_ok()
{
continue;
}
// Snapshot the keys so the lock is not held across the fetches below.
let bundle_keys = {
let bundles = inner.bundles.read().unwrap();
let mut keys = bundles.keys().cloned().collect::<Vec<_>>();
keys.sort();
keys
};
for bundle_key in bundle_keys {
let _ = Client::fetch_guards_for_bundle(&inner, &bundle_key).await;
}
}
})
};
let flush_task = {
let inner = Arc::clone(&inner);
let interval = inner.opts.flush_rate_seconds;
tokio::spawn(async move {
let mut ticker = time::interval(Duration::from_secs(interval));
// The first tick of a Tokio interval completes immediately; consume it so
// the first flush happens one full period after start.
ticker.tick().await;
loop {
ticker.tick().await;
let _ = Client::flush_signals_inner(&inner).await;
}
})
};
Ok(Client {
inner,
refresh_task,
flush_task,
is_shutdown: AtomicBool::new(false),
})
}
/// Creates a scope with the given properties that evaluates against the public
/// bundle and has no protected context.
pub fn create_scope(&self, properties: Properties) -> Scope {
Scope {
inner: Arc::clone(&self.inner),
properties,
bundle_key: PUBLIC_BUNDLE_KEY.to_owned(),
protected_context: None,
}
}
/// Builds a public scope with no properties; used by the client-level convenience methods.
fn public_scope(&self) -> Scope {
self.create_scope(Properties::new())
}
/// Evaluates guard `name` against an empty public scope with default options,
/// recording a `guard_check` signal.
#[track_caller]
pub fn is_open(&self, name: &str) -> bool {
self.is_open_with_options(name, &Options::default())
}
/// Evaluates guard `name` against an empty public scope with per-call `options`,
/// with signal emission enabled. The caller's location becomes the callsite id.
#[track_caller]
pub fn is_open_with_options(&self, name: &str, options: &Options) -> bool {
let scope = self.public_scope();
Self::evaluate_guard_in_scope(&self.inner, &scope, name, options, true, Location::caller())
.result
}
/// Evaluates guard `name` against an empty public scope with default options,
/// without recording a signal.
#[track_caller]
pub fn peek_is_open(&self, name: &str) -> bool {
self.peek_is_open_with_options(name, &Options::default())
}
/// Evaluates guard `name` against an empty public scope with per-call `options`,
/// with signal emission disabled (no signal buffered, rate limits not consumed).
#[track_caller]
pub fn peek_is_open_with_options(&self, name: &str, options: &Options) -> bool {
let scope = self.public_scope();
Self::evaluate_guard_in_scope(
&self.inner,
&scope,
name,
options,
false,
Location::caller(),
)
.result
}
/// Runs `f` when guard `name` is open for an empty public scope with default
/// options; returns `None` without running `f` otherwise.
#[track_caller]
pub fn execute_if_open<F, R>(&self, name: &str, f: F) -> Option<R>
where
F: FnOnce() -> R,
{
self.execute_if_open_with_options(name, &Options::default(), f)
}
/// Delegates to `Scope::execute_if_open_with_options` on an empty public scope.
#[track_caller]
pub fn execute_if_open_with_options<F, R>(
&self,
name: &str,
options: &Options,
f: F,
) -> Option<R>
where
F: FnOnce() -> R,
{
self.public_scope()
.execute_if_open_with_options(name, options, f)
}
/// Fallible variant of `execute_if_open` for an empty public scope with default options.
#[track_caller]
pub fn try_execute_if_open<F, R, E>(&self, name: &str, f: F) -> Result<Option<R>, E>
where
F: FnOnce() -> Result<R, E>,
{
self.try_execute_if_open_with_options(name, &Options::default(), f)
}
/// Delegates to `Scope::try_execute_if_open_with_options` on an empty public scope.
#[track_caller]
pub fn try_execute_if_open_with_options<F, R, E>(
&self,
name: &str,
options: &Options,
f: F,
) -> Result<Option<R>, E>
where
F: FnOnce() -> Result<R, E>,
{
self.public_scope()
.try_execute_if_open_with_options(name, options, f)
}
/// Runs `f` inside an execution context on the current thread, so signals
/// emitted during `f` share one execution id. If a context is already active
/// on this thread, `f` joins it.
pub fn with_execution<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
Self::with_execution_inner(&self.inner, f)
}
fn with_execution_inner<F, R>(_inner: &Arc<Inner>, f: F) -> R
where
F: FnOnce() -> R,
{
EXECUTION_STATE.with(|slot| {
if slot.borrow().is_some() {
return f();
}
let previous = slot.replace(Some(ExecutionState {
execution_id: next_signal_id(),
sequence_number: 0,
last_signal_id: None,
}));
debug_assert!(previous.is_none());
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
slot.replace(None);
match result {
Ok(value) => value,
Err(payload) => std::panic::resume_unwind(payload),
}
})
}
/// Sends all buffered signals and pending unadopted-guard names now.
///
/// # Errors
/// Returns the first `reqwest::Error` encountered while sending.
pub async fn flush(&self) -> Result<(), reqwest::Error> {
Self::flush_signals_inner(&self.inner).await
}
/// Stops the background tasks and performs a final flush.
///
/// # Errors
/// Returns the error of the final flush, if any.
pub async fn shutdown(&self) -> Result<(), reqwest::Error> {
self.stop_background_tasks();
self.flush().await
}
pub fn set_guards_for_testing(&self, guards: Vec<Guard>) {
{
let mut bundles = self.inner.bundles.write().unwrap();
bundles.insert(
PUBLIC_BUNDLE_KEY.to_owned(),
GuardBundle {
key: PUBLIC_BUNDLE_KEY.to_owned(),
guards: guards
.into_iter()
.map(|guard| (guard.name.clone(), guard))
.collect(),
ready: true,
etag: String::new(),
protected_context: None,
refresh_rate_seconds: self.inner.opts.refresh_rate_seconds,
},
);
*self.inner.current_refresh_rate_seconds.lock().unwrap() =
Self::recompute_refresh_interval(&bundles, self.inner.opts.refresh_rate_seconds);
}
}
/// Test hook: number of bundles currently tracked (including the public one).
pub fn known_bundle_count_for_testing(&self) -> usize {
self.inner.bundles.read().unwrap().len()
}
/// Test hook: the refresh interval, in seconds, the refresh loop currently uses.
pub fn current_refresh_rate_for_testing(&self) -> u64 {
*self.inner.current_refresh_rate_seconds.lock().unwrap()
}
/// Test hook: sorted snapshot of the unadopted guard names waiting to be sent.
pub fn pending_unadopted_guards_for_testing(&self) -> Vec<String> {
    let pending = self.inner.pending_unadopted_guards.lock().unwrap();
    let mut names = Vec::with_capacity(pending.len());
    names.extend(pending.iter().cloned());
    // Release the lock before sorting; only the copied names are needed from here.
    drop(pending);
    names.sort();
    names
}
/// Builds a not-ready bundle with no guards and an empty ETag for the given key,
/// protected context and refresh rate.
fn create_empty_bundle(
key: String,
protected_context: Option<ProtectedContext>,
refresh_rate_seconds: u64,
) -> GuardBundle {
GuardBundle {
key,
guards: HashMap::new(),
ready: false,
etag: String::new(),
protected_context,
refresh_rate_seconds,
}
}
/// Fetches the public bundle unless it is already marked ready.
///
/// # Errors
/// Propagates the `reqwest::Error` from the fetch.
async fn ensure_public_bundle_ready(inner: &Arc<Inner>) -> Result<(), reqwest::Error> {
    let already_ready = {
        // Scoped so the read guard is released before the await below.
        let bundles = inner.bundles.read().unwrap();
        matches!(bundles.get(PUBLIC_BUNDLE_KEY), Some(bundle) if bundle.ready)
    };
    if !already_ready {
        Self::fetch_guards_for_bundle(inner, PUBLIC_BUNDLE_KEY).await?;
    }
    Ok(())
}
/// Returns the bundle key for `protected_context`, registering an empty bundle
/// for it if none exists and fetching its guards when the bundle is not ready.
///
/// # Errors
/// Propagates the `reqwest::Error` from the fetch.
async fn ensure_bundle_for_protected_context(
inner: &Arc<Inner>,
protected_context: ProtectedContext,
) -> Result<String, reqwest::Error> {
let bundle_key = Self::protected_context_cache_key(&protected_context);
// The write guard lives only inside this block, so it is released before the await.
let ready = {
let mut bundles = inner.bundles.write().unwrap();
bundles
.entry(bundle_key.clone())
.or_insert_with(|| {
Self::create_empty_bundle(
bundle_key.clone(),
Some(protected_context.clone()),
inner.opts.refresh_rate_seconds,
)
})
.ready
};
if ready {
return Ok(bundle_key);
}
Self::fetch_guards_for_bundle(inner, &bundle_key).await?;
Ok(bundle_key)
}
/// Derives the bundle key for a protected context.
///
/// The key is the signature, a NUL, an empty segment, and then one
/// NUL-prefixed `name=value` segment per property in ascending name order, so
/// the same context always maps to the same key.
fn protected_context_cache_key(protected_context: &ProtectedContext) -> String {
    let mut names = protected_context
        .properties
        .keys()
        .cloned()
        .collect::<Vec<_>>();
    names.sort();

    let mut cache_key = protected_context.signature.clone();
    // Separator that precedes the (always empty) second segment.
    cache_key.push('\0');
    for name in names {
        let value = protected_context
            .properties
            .get(&name)
            .cloned()
            .unwrap_or_default();
        cache_key.push('\0');
        cache_key.push_str(&format!("{name}={value}"));
    }
    cache_key
}
/// Returns the smallest non-zero refresh rate among `bundles`, or `default`
/// when no bundle has a non-zero rate.
fn recompute_refresh_interval(bundles: &HashMap<String, GuardBundle>, default: u64) -> u64 {
    let mut shortest: Option<u64> = None;
    for bundle in bundles.values() {
        let rate = bundle.refresh_rate_seconds;
        if rate == 0 {
            continue;
        }
        shortest = Some(shortest.map_or(rate, |current| current.min(rate)));
    }
    shortest.unwrap_or(default)
}
/// Fetches the guards for `bundle_key` from the backend and stores them as a
/// ready bundle, then recomputes the shared refresh interval.
///
/// A 304 response or any non-success status leaves the stored bundle untouched
/// and returns `Ok(())` (non-success statuses are logged).
///
/// # Errors
/// Returns the `reqwest::Error` from sending the request or decoding the body.
async fn fetch_guards_for_bundle(
inner: &Arc<Inner>,
bundle_key: &str,
) -> Result<(), reqwest::Error> {
// Work on a snapshot so no lock is held across the network call.
let bundle = {
let bundles = inner.bundles.read().unwrap();
bundles.get(bundle_key).cloned().unwrap_or_else(|| {
Self::create_empty_bundle(
bundle_key.to_owned(),
None,
inner.opts.refresh_rate_seconds,
)
})
};
let url = format!(
"{}/api/v1/guards",
inner.opts.backend_url.trim_end_matches('/')
);
let mut req = inner
.http
.post(&url)
.bearer_auth(&inner.project_client_key_id)
.json(&GetGuardsRequest {
project_client_key_id: inner.project_client_key_id.clone(),
environment: inner.opts.environment.clone().unwrap_or_default(),
protected_context: bundle.protected_context.clone(),
});
// Conditional request: lets the backend answer 304 when nothing changed.
if !bundle.etag.is_empty() {
req = req.header("If-None-Match", &bundle.etag);
}
let response = req.send().await?;
if response.status().as_u16() == 304 {
return Ok(());
}
if !response.status().is_success() {
inner.log(format!(
"guard fetch failed for bundle {:?}: HTTP {}",
bundle_key,
response.status().as_u16()
));
return Ok(());
}
let body: GuardsResponse = response.json().await?;
// A rate of 0 (including a missing field) means "use the configured rate".
let effective_refresh_rate = if body.refresh_rate_seconds == 0 {
inner.opts.refresh_rate_seconds
} else {
body.refresh_rate_seconds
};
{
let mut bundles = inner.bundles.write().unwrap();
bundles.insert(
bundle.key.clone(),
GuardBundle {
key: bundle.key,
guards: body
.guards
.into_iter()
.map(|guard| (guard.name.clone(), guard))
.collect(),
ready: true,
etag: body.etag,
protected_context: bundle.protected_context,
refresh_rate_seconds: effective_refresh_rate,
},
);
let next_refresh =
Self::recompute_refresh_interval(&bundles, inner.opts.refresh_rate_seconds);
let mut current = inner.current_refresh_rate_seconds.lock().unwrap();
// Wake the refresh loop only when the interval actually changed.
if *current != next_refresh {
*current = next_refresh;
inner.refresh_notify.notify_one();
}
}
Ok(())
}
/// Evaluates guard `name` for `scope` and optionally buffers a `guard_check` signal.
///
/// Resolution order:
/// 1. Bundle not ready: return the per-call fallback, or the client fallback.
/// 2. Guard missing from the bundle, or present but not adopted: record the
///    name as unadopted and return `true`.
/// 3. Otherwise evaluate the guard against the scope properties overlaid with
///    the per-call properties, then apply the rate limit if one is configured.
///
/// With `emit_signal` set, the rate limit is consumed and a signal is buffered
/// using `caller` as the callsite; otherwise the rate limit is only inspected.
fn evaluate_guard_in_scope(
inner: &Arc<Inner>,
scope: &Scope,
name: &str,
options: &Options,
emit_signal: bool,
caller: &'static Location<'static>,
) -> GuardEvaluation {
// Scope's bundle, else the public bundle, else an empty not-ready bundle.
let bundle = {
let bundles = inner.bundles.read().unwrap();
bundles
.get(&scope.bundle_key)
.cloned()
.or_else(|| bundles.get(PUBLIC_BUNDLE_KEY).cloned())
.unwrap_or_else(|| {
Self::create_empty_bundle(
PUBLIC_BUNDLE_KEY.to_owned(),
None,
inner.opts.refresh_rate_seconds,
)
})
};
if !bundle.ready {
return GuardEvaluation {
result: options.fallback.unwrap_or(inner.opts.fallback),
guard: None,
props: None,
signal: None,
};
}
let guard = match bundle.guards.get(name) {
Some(guard) => guard.clone(),
None => {
Self::record_unadopted_guard(inner, name.to_owned());
return GuardEvaluation {
result: true,
guard: None,
props: None,
signal: None,
};
}
};
if !guard.adopted {
Self::record_unadopted_guard(inner, name.to_owned());
return GuardEvaluation {
result: true,
guard: Some(guard),
props: None,
signal: None,
};
}
// Per-call properties override scope properties with the same key.
let mut props = scope.properties.clone();
for (key, value) in &options.properties.0 {
props.insert(key.clone(), value.clone());
}
let mut result = evaluate_guard(&guard, &props);
// Rate limiting only applies to evaluations that would otherwise be open.
if result && guard.rate_limit_per_minute > 0 {
result = if emit_signal {
Self::check_rate_limit(
inner,
name,
guard.rate_limit_per_minute,
&guard.rate_limit_properties,
&props,
)
} else {
Self::would_pass_rate_limit(
inner,
name,
guard.rate_limit_per_minute,
&guard.rate_limit_properties,
&props,
)
};
}
let signal = if emit_signal {
Some(Self::buffer_signal_inner(
inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result,
props: props.clone(),
metadata: Self::next_signal_metadata(None),
callsite_id: format!("{}:{}", caller.file(), caller.line()),
kind: "guard_check".into(),
measurement: if Self::is_measurement_enabled(inner, &guard, options) {
Self::capture_guard_check_measurement()
} else {
None
},
},
))
} else {
None
};
GuardEvaluation {
result,
guard: Some(guard),
props: Some(props),
signal,
}
}
/// Queues `name` for unadopted-guard reporting the first time it is seen;
/// later calls with the same name do nothing.
fn record_unadopted_guard(inner: &Arc<Inner>, name: String) {
    // The lock guard is a temporary, so it is released at the end of this statement.
    let first_sighting = inner
        .reported_unadopted_guards
        .lock()
        .unwrap()
        .insert(name.clone());
    if first_sighting {
        inner.pending_unadopted_guards.lock().unwrap().insert(name);
    }
}
/// Builds the rate-limit bucket key for a guard.
///
/// The key is the guard name followed by one NUL-prefixed `property=value`
/// segment per entry in `rate_limit_properties`; a property missing from
/// `props` contributes an empty value. With no rate-limit properties the key
/// is just the guard name.
fn rate_limit_bucket_key(
    name: &str,
    rate_limit_properties: &[String],
    props: &Properties,
) -> String {
    rate_limit_properties
        .iter()
        .fold(name.to_owned(), |mut bucket, property_name| {
            bucket.push('\0');
            bucket.push_str(property_name);
            bucket.push('=');
            if let Some(value) = props.get(property_name) {
                bucket.push_str(&Self::property_value_to_string(value));
            }
            bucket
        })
}
fn property_value_to_string(value: &crate::types::PropertyValue) -> String {
match value {
crate::types::PropertyValue::Bool(v) => v.to_string(),
crate::types::PropertyValue::Number(v) => v.to_string(),
crate::types::PropertyValue::Text(v) => v.clone(),
}
}
/// Consumes one slot of the guard's per-minute rate limit.
///
/// Returns `true` and increments the bucket counter when fewer than
/// `limit_per_minute` checks were admitted in the current 60-second window.
/// Returns `false` without counting otherwise. A window older than 60 seconds
/// is restarted first.
fn check_rate_limit(
    inner: &Arc<Inner>,
    name: &str,
    limit_per_minute: i32,
    rate_limit_properties: &[String],
    props: &Properties,
) -> bool {
    let bucket = Self::rate_limit_bucket_key(name, rate_limit_properties, props);
    let now = Instant::now();
    let mut buckets = inner.rate_limit_state.lock().unwrap();
    let window = buckets.entry(bucket).or_insert(RateLimitEntry {
        window_start: now,
        count: 0,
    });
    let window_expired = now.duration_since(window.window_start) >= Duration::from_secs(60);
    if window_expired {
        *window = RateLimitEntry {
            window_start: now,
            count: 0,
        };
    }
    let admitted = window.count < limit_per_minute;
    if admitted {
        window.count += 1;
    }
    admitted
}
/// Reports whether a check would currently be admitted by the rate limit,
/// without touching any counter.
///
/// A bucket that does not exist yet, or whose window is at least 60 seconds
/// old, always passes.
fn would_pass_rate_limit(
    inner: &Arc<Inner>,
    name: &str,
    limit_per_minute: i32,
    rate_limit_properties: &[String],
    props: &Properties,
) -> bool {
    let bucket = Self::rate_limit_bucket_key(name, rate_limit_properties, props);
    let now = Instant::now();
    let buckets = inner.rate_limit_state.lock().unwrap();
    buckets.get(&bucket).map_or(true, |window| {
        now.duration_since(window.window_start) >= Duration::from_secs(60)
            || window.count < limit_per_minute
    })
}
/// Builds a `Signal` from `input`, appends it to the signal buffer and returns a copy.
///
/// The signal carries the number of signals dropped since the previous signal
/// was built (that counter is reset here). When the buffer is at capacity
/// (`flush_size * flush_buffer_multiplier`) the oldest signal is discarded and
/// counted as dropped. Once the buffer holds at least `flush_size` signals an
/// asynchronous flush is spawned on the stored runtime handle.
fn buffer_signal_inner(inner: &Arc<Inner>, input: BufferedSignalInput) -> Signal {
// Milliseconds since the Unix epoch, clamped to `i64::MAX`; 0 if the clock is before the epoch.
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis()
.min(i64::MAX as u128) as i64;
// Take-and-reset so each drop is reported on exactly one signal.
let dropped_signals_since_last = {
let mut dropped = inner.dropped_signals_pending.lock().unwrap();
let pending = *dropped;
*dropped = 0;
pending
};
let signal = Signal {
guard_name: input.guard_name,
result: input.result,
properties: input.props.0,
timestamp_ms: ts,
trace: Client::capture_trace_context(),
signal_id: input.metadata.signal_id,
execution_id: input.metadata.execution_id,
parent_signal_id: input.metadata.parent_signal_id,
sequence_number: input.metadata.sequence_number,
callsite_id: input.callsite_id,
kind: input.kind,
dropped_signals_since_last,
measurement: input.measurement,
};
let flush_size = inner.opts.flush_size;
let max_buffered_signals = flush_size * inner.opts.flush_buffer_multiplier;
let mut buffer = inner.signal_buffer.lock().unwrap();
// At capacity: evict the oldest entry to make room for the new one.
if buffer.len() >= max_buffered_signals {
buffer.remove(0);
*inner.dropped_signals_pending.lock().unwrap() += 1;
}
buffer.push(signal.clone());
if buffer.len() >= flush_size {
// Release the buffer lock before spawning; the flush task locks it again.
drop(buffer);
let inner = Arc::clone(inner);
let runtime = inner.runtime.clone();
runtime.spawn(async move {
let _ = Client::flush_signals_inner(&inner).await;
});
}
signal
}
/// Aborts the refresh and flush tasks. Only the first call has any effect.
fn stop_background_tasks(&self) {
    let was_already_shutdown = self.is_shutdown.swap(true, Ordering::SeqCst);
    if !was_already_shutdown {
        self.refresh_task.abort();
        self.flush_task.abort();
    }
}
/// Drains the signal buffer and the pending unadopted-guard names and sends
/// each non-empty batch (signals first).
///
/// # Errors
/// Both batches are always attempted; the first error encountered is returned.
async fn flush_signals_inner(inner: &Arc<Inner>) -> Result<(), reqwest::Error> {
    let signals = {
        let mut buffer = inner.signal_buffer.lock().unwrap();
        std::mem::take(&mut *buffer)
    };
    let unadopted = Self::take_pending_unadopted_guards(inner);

    let signal_error = if signals.is_empty() {
        None
    } else {
        Self::flush_signal_batch(inner, signals).await.err()
    };
    let unadopted_error = if unadopted.is_empty() {
        None
    } else {
        Self::flush_unadopted_guard_batch(inner, unadopted)
            .await
            .err()
    };

    match signal_error.or(unadopted_error) {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Posts `batch` to the signals endpoint.
///
/// On failure (transport error or non-success status) the batch is put back at
/// the front of the signal buffer, the buffer is trimmed to its capacity with
/// the overflow counted as dropped, and the error is returned.
///
/// # Errors
/// Returns the `reqwest::Error` from sending or from a non-success status.
async fn flush_signal_batch(
inner: &Arc<Inner>,
batch: Vec<Signal>,
) -> Result<(), reqwest::Error> {
// Request body; borrows the batch so it can be re-queued if the send fails.
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct Payload<'a> {
project_client_key_id: &'a str,
environment: String,
signals: &'a [Signal],
}
let payload = Payload {
project_client_key_id: &inner.project_client_key_id,
environment: inner.opts.environment.clone().unwrap_or_default(),
signals: &batch,
};
let url = format!(
"{}/api/v1/signals",
inner.opts.backend_url.trim_end_matches('/')
);
let result = inner
.http
.post(&url)
.bearer_auth(&inner.project_client_key_id)
.json(&payload)
.send()
.await
.and_then(|response| response.error_for_status());
if let Err(err) = result {
inner.log(format!(
"signal flush failed: {err} for {} signals",
batch.len()
));
// Re-queue: failed batch first, then whatever was buffered in the meantime.
let mut buffer = inner.signal_buffer.lock().unwrap();
let old = std::mem::take(&mut *buffer);
*buffer = batch;
buffer.extend(old);
let max_buffered_signals = inner.opts.flush_size * inner.opts.flush_buffer_multiplier;
// Over capacity: the entries at the tail (the most recently buffered) are dropped.
if buffer.len() > max_buffered_signals {
*inner.dropped_signals_pending.lock().unwrap() +=
(buffer.len() - max_buffered_signals) as i32;
buffer.truncate(max_buffered_signals);
}
return Err(err);
}
Ok(())
}
/// Posts `guard_names` to the unadopted-guards endpoint.
///
/// On failure (transport error or non-success status) the failure is logged,
/// the names are put back into the pending set, and the error is returned.
///
/// # Errors
/// Returns the `reqwest::Error` from sending or from a non-success status.
async fn flush_unadopted_guard_batch(
    inner: &Arc<Inner>,
    guard_names: Vec<String>,
) -> Result<(), reqwest::Error> {
    let endpoint = format!(
        "{}/api/v1/unadopted-guards",
        inner.opts.backend_url.trim_end_matches('/')
    );
    let request_body = SendUnadoptedGuardsRequest {
        project_client_key_id: inner.project_client_key_id.clone(),
        environment: inner.opts.environment.clone().unwrap_or_default(),
        // Cloned so the original list is still available for re-queuing.
        guard_names: guard_names.clone(),
    };
    let sent = inner
        .http
        .post(&endpoint)
        .bearer_auth(&inner.project_client_key_id)
        .json(&request_body)
        .send()
        .await;
    match sent.and_then(|response| response.error_for_status()) {
        Ok(_) => Ok(()),
        Err(err) => {
            inner.log(format!(
                "unadopted guard flush failed: {err} for {} guards",
                guard_names.len()
            ));
            Self::requeue_pending_unadopted_guards(inner, guard_names);
            Err(err)
        }
    }
}
/// Empties the pending unadopted-guard set and returns its names in sorted order.
fn take_pending_unadopted_guards(inner: &Arc<Inner>) -> Vec<String> {
    let mut pending = inner.pending_unadopted_guards.lock().unwrap();
    let mut names: Vec<String> = pending.drain().collect();
    names.sort();
    names
}
/// Puts `guard_names` back into the pending unadopted-guard set; an empty list
/// is ignored without taking the lock.
fn requeue_pending_unadopted_guards(inner: &Arc<Inner>, guard_names: Vec<String>) {
    if !guard_names.is_empty() {
        let mut pending = inner.pending_unadopted_guards.lock().unwrap();
        pending.extend(guard_names);
    }
}
/// Produces the ids and sequence number for the next signal on this thread.
///
/// Inside an execution context the signal joins that execution: the sequence
/// number is incremented, the parent is `parent_signal_id_override` or else
/// the previously issued signal, and this signal becomes the new "last"
/// signal. Outside any context the signal gets its own fresh execution id, no
/// parent, and sequence number 1 (the override is ignored).
fn next_signal_metadata(parent_signal_id_override: Option<String>) -> SignalMetadata {
let signal_id = next_signal_id();
EXECUTION_STATE.with(|slot| {
let mut state = slot.borrow_mut();
if let Some(state) = state.as_mut() {
state.sequence_number += 1;
let metadata = SignalMetadata {
signal_id: signal_id.clone(),
execution_id: state.execution_id.clone(),
parent_signal_id: parent_signal_id_override
.or_else(|| state.last_signal_id.clone()),
sequence_number: state.sequence_number,
};
state.last_signal_id = Some(signal_id);
metadata
} else {
SignalMetadata {
signal_id,
execution_id: next_signal_id(),
parent_signal_id: None,
sequence_number: 1,
}
}
})
}
/// Measurement is enabled only when it is not disabled by the client options,
/// the per-call options, or the guard itself (an unset guard flag counts as
/// not disabled).
fn is_measurement_enabled(inner: &Arc<Inner>, guard: &Guard, options: &Options) -> bool {
    let disabled_by_client = inner.opts.disable_measurement;
    let disabled_by_call = options.disable_measurement;
    let disabled_by_guard = guard.disable_measurement.unwrap_or(false);
    !(disabled_by_client || disabled_by_call || disabled_by_guard)
}
/// Measurement attached to `guard_check` signals; currently always `None`.
fn capture_guard_check_measurement() -> Option<SignalPerformance> {
None
}
/// Builds the measurement for a `guard_execution` signal.
///
/// Only the elapsed time since `started_at` (nanoseconds, clamped to
/// `i64::MAX`), the completion flag and the error class are filled in; every
/// other metric is left as `None`.
fn capture_guard_execution_measurement(
    started_at: Instant,
    completed: bool,
    error_class: Option<String>,
) -> SignalPerformance {
    let duration_ns = i64::try_from(started_at.elapsed().as_nanos()).unwrap_or(i64::MAX);
    let execution = GuardExecutionPerformance {
        duration_ns,
        rss_end_bytes: None,
        heap_used_end_bytes: None,
        heap_total_end_bytes: None,
        cpu_time_end_ns: None,
        gc_count_end: None,
        thread_count_end: None,
        completed,
        error_class,
    };
    SignalPerformance {
        guard_check: None,
        guard_execution: Some(execution),
    }
}
/// Error class recorded for a panicking closure; the payload is not inspected
/// and the result is always `"panic"`.
fn panic_error_class(_payload: &(dyn std::any::Any + Send)) -> String {
"panic".into()
}
/// Captures the trace and span ids of the current OpenTelemetry context.
///
/// Returns `None` when the current span context is not valid. The parent span
/// id is always left empty.
fn capture_trace_context() -> Option<TraceContext> {
let context = opentelemetry::Context::current();
let span = context.span();
let span_context = span.span_context();
if !span_context.is_valid() {
return None;
}
Some(TraceContext {
trace_id: span_context.trace_id().to_string(),
span_id: span_context.span_id().to_string(),
parent_span_id: String::new(),
})
}
}
/// Generates an id of the form `<nanos-hex>-<counter-hex>` from the current
/// wall-clock time (nanoseconds since the Unix epoch, 0 if the clock is before
/// the epoch) and the incremented process-wide signal counter.
fn next_signal_id() -> String {
    let sequence = SIGNAL_COUNTER.fetch_add(1, Ordering::Relaxed) + 1;
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_nanos())
        .unwrap_or_default();
    format!("{nanos:x}-{sequence:x}")
}
impl Drop for Client {
    /// Stops the background tasks and, if signals are still buffered, logs a
    /// warning. Nothing is flushed here; a poisoned buffer lock is tolerated
    /// so the count can still be reported.
    fn drop(&mut self) {
        self.stop_background_tasks();
        let unflushed = self
            .inner
            .signal_buffer
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();
        if unflushed == 0 {
            return;
        }
        let pending = unflushed;
        self.inner.log(format!(
            "client dropped with {pending} unflushed signals — call shutdown().await before dropping to avoid signal loss"
        ));
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Operator, PropertyValue, Rule};
use opentelemetry::trace::{Span, TraceContextExt, Tracer, TracerProvider};
use opentelemetry_sdk::trace::SdkTracerProvider;
use std::io::{Read, Write};
use std::net::TcpListener;
use std::thread;
use tokio::runtime::{Builder, Runtime};
/// Builds a multi-threaded Tokio runtime and a client pointed at
/// `http://localhost:9999`. The runtime is returned alongside the client so
/// the caller keeps it alive for as long as the client is used.
fn make_client() -> (Runtime, Client) {
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: "http://localhost:9999".into(),
..Default::default()
},
))
.unwrap();
(runtime, client)
}
/// Test helper: builds an adopted guard with the given rules, default value
/// and rate-limit settings; the guard-level measurement flag is left unset.
fn adopted_guard(
    name: &str,
    default_value: bool,
    rules: Vec<Rule>,
    rate_limit_per_minute: i32,
    rate_limit_properties: Vec<&str>,
) -> Guard {
    let bucket_properties = rate_limit_properties
        .iter()
        .map(|property| (*property).to_owned())
        .collect();
    Guard {
        name: name.to_owned(),
        rules,
        default_value,
        adopted: true,
        rate_limit_per_minute,
        rate_limit_properties: bucket_properties,
        disable_measurement: None,
    }
}
/// Test helper: builds an enabled rule on property `prop`.
fn rule(prop: &str, operator: Operator, values: Vec<PropertyValue>, result: bool) -> Rule {
Rule {
property_name: prop.into(),
operator,
values,
result,
enabled: true,
}
}
/// Test helper: per-call options that only carry `properties` (no fallback
/// override, measurement left enabled).
fn with_props(properties: Properties) -> Options {
Options {
properties,
fallback: None,
disable_measurement: false,
}
}
/// Test helper: reads one HTTP/1.1 request (header block plus a
/// `Content-Length` body) from `stream` and returns it as lossily decoded text.
///
/// Reading continues until the blank line that ends the headers has arrived,
/// so a header block split across several TCP reads is handled. The
/// `Content-Length` header is matched case-insensitively because HTTP field
/// names are case-insensitive and clients commonly send them in lowercase.
///
/// # Panics
/// Panics on I/O errors or if the peer closes the connection before the
/// header block is complete.
fn read_request(stream: &mut std::net::TcpStream) -> String {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    let mut request = Vec::new();
    let mut chunk = [0_u8; 4096];
    let header_end = loop {
        if let Some(start) = request
            .windows(HEADER_TERMINATOR.len())
            .position(|window| window == HEADER_TERMINATOR)
        {
            break start + HEADER_TERMINATOR.len();
        }
        let size = stream.read(&mut chunk).unwrap();
        assert!(
            size > 0,
            "connection closed before the request headers were complete"
        );
        request.extend_from_slice(&chunk[..size]);
    };

    let content_length = String::from_utf8_lossy(&request[..header_end])
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(field, _)| field.trim().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse::<usize>().ok())
        .unwrap_or(0);

    // Part of the body may already have arrived together with the headers.
    let already_read = request.len() - header_end;
    if content_length > already_read {
        let mut rest = vec![0_u8; content_length - already_read];
        stream.read_exact(&mut rest).unwrap();
        request.extend_from_slice(&rest);
    }
    String::from_utf8_lossy(&request).into_owned()
}
/// Writes a minimal HTTP/1.1 response to `stream`: the given status code with the fixed
/// reason phrase `TEST`, a `Content-Length` equal to the byte length of `body`, a JSON
/// content type, `Connection: close`, and then `body` itself.
///
/// Panics if the write fails.
fn write_response(stream: &mut std::net::TcpStream, status: u16, body: &str) {
    let content_length = body.len();
    let payload = format!(
        "HTTP/1.1 {status} TEST\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{body}",
        content_length
    );
    let bytes = payload.as_bytes();
    stream.write_all(bytes).unwrap();
}
/// `with_defaults` must replace zeroed numeric options and an empty backend URL with the
/// values asserted below.
#[test]
fn options_with_defaults_normalize_invalid_values() {
    // Every numeric option is zero and the backend URL is empty.
    let unset = ClientOptions {
        environment: None,
        fallback: false,
        refresh_rate_seconds: 0,
        flush_rate_seconds: 0,
        flush_size: 0,
        backend_url: String::new(),
        quiet: true,
        http_timeout_seconds: 0,
        flush_buffer_multiplier: 0,
        disable_measurement: false,
    };
    let normalized = unset.with_defaults();
    assert_eq!(normalized.refresh_rate_seconds, 30);
    assert_eq!(normalized.flush_rate_seconds, 10);
    assert_eq!(normalized.flush_size, 500);
    assert_eq!(normalized.backend_url, "https://api.liteguard.io");
    assert_eq!(normalized.http_timeout_seconds, 4);
    assert_eq!(normalized.flush_buffer_multiplier, 4);
}
/// Rules are evaluated against the scope's properties, and properties supplied with an
/// individual call take part in that evaluation.
#[test]
fn evaluates_rules_using_explicit_scopes_and_per_call_properties() {
    let (_runtime, client) = make_client();
    // Closed by default; opens when `plan` equals "pro".
    let plan_guard = adopted_guard(
        "feature.plan",
        false,
        vec![rule(
            "plan",
            Operator::Equals,
            vec![PropertyValue::Text("pro".into())],
            true,
        )],
        0,
        vec![],
    );
    // Closed by default; opens when `userId` is in the canary list.
    let canary_guard = adopted_guard(
        "feature.canary",
        false,
        vec![rule(
            "userId",
            Operator::In,
            vec![PropertyValue::Text("canary-1".into())],
            true,
        )],
        0,
        vec![],
    );
    client.set_guards_for_testing(vec![plan_guard, canary_guard]);

    let scope_properties = Properties::new()
        .set("plan", "free")
        .set("userId", "normal-user");
    let scope = client.create_scope(scope_properties);

    // The scope's own properties match neither rule.
    let plan_open = scope.is_open("feature.plan");
    assert!(!plan_open);

    // A per-call `userId` makes the canary rule match even though the scope carries
    // a different one.
    let canary_options = with_props(Properties::new().set("userId", "canary-1"));
    let canary_open = scope.is_open_with_options("feature.canary", &canary_options);
    assert!(canary_open);
}
/// Two scopes derived on different threads must each see only their own `plan` property,
/// even when both evaluate the same guard at the same moment.
#[test]
fn scopes_isolate_request_local_properties_across_threads() {
    let (_runtime, client) = make_client();
    let pro_only = rule(
        "plan",
        Operator::Equals,
        vec![PropertyValue::Text("pro".into())],
        true,
    );
    client.set_guards_for_testing(vec![adopted_guard(
        "feature.plan",
        false,
        vec![pro_only],
        0,
        vec![],
    )]);
    let free_scope = client.create_scope(Properties::new().set("requestId", "free"));
    let pro_scope = client.create_scope(Properties::new().set("requestId", "pro"));

    // Two rendezvous points shared by both workers and this thread (three parties each):
    // the first is passed once both workers have derived their scope, the second
    // releases them to evaluate the guard.
    let ready = Arc::new(std::sync::Barrier::new(3));
    let release = Arc::new(std::sync::Barrier::new(3));

    let (free_ready, free_release) = (ready.clone(), release.clone());
    let free = thread::spawn(move || {
        let derived = free_scope.add_properties(Properties::new().set("plan", "free"));
        free_ready.wait();
        free_release.wait();
        derived.is_open("feature.plan")
    });

    let (pro_ready, pro_release) = (ready.clone(), release.clone());
    let pro = thread::spawn(move || {
        let derived = pro_scope.add_properties(Properties::new().set("plan", "pro"));
        pro_ready.wait();
        pro_release.wait();
        derived.is_open("feature.plan")
    });

    ready.wait();
    release.wait();

    let free_open = free.join().unwrap();
    let pro_open = pro.join().unwrap();
    let results = [free_open, pro_open];
    assert!(results.contains(&false));
    assert!(results.contains(&true));
}
/// Binding a protected context fetches a separate guard bundle for it, binding the same
/// context again reuses the cached bundle, and clearing the context falls back to the
/// public bundle.
#[test]
fn protected_context_binds_to_derived_scopes_and_caches_bundles() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();
    let request_bodies = Arc::new(Mutex::new(Vec::<String>::new()));
    let recorded = Arc::clone(&request_bodies);

    // Mock backend: answers exactly three requests. Requests that carry a protected
    // context get a bundle where the guard defaults to open; all others get one where
    // it defaults to closed.
    let server = thread::spawn(move || {
        for _ in 0..3 {
            let (mut stream, _) = listener.accept().unwrap();
            let request = read_request(&mut stream);
            let is_protected = request.contains("\"protectedContext\"");
            recorded.lock().unwrap().push(request);
            let body = if is_protected {
                "{\"guards\":[{\"name\":\"private.dashboard\",\"rules\":[],\"defaultValue\":true,\"adopted\":true,\"rateLimitPerMinute\":0,\"rateLimitProperties\":[]}],\"refreshRateSeconds\":30,\"etag\":\"protected\"}"
            } else {
                "{\"guards\":[{\"name\":\"private.dashboard\",\"rules\":[],\"defaultValue\":false,\"adopted\":true,\"rateLimitPerMinute\":0,\"rateLimitProperties\":[]}],\"refreshRateSeconds\":30,\"etag\":\"public\"}"
            };
            write_response(&mut stream, 200, body);
        }
    });

    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        backend_url: format!("http://{address}"),
        refresh_rate_seconds: 3600,
        flush_rate_seconds: 3600,
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();

    let public_scope = client.create_scope(Properties::new());
    let alice_context = ProtectedContext {
        properties: HashMap::from([(String::from("email"), String::from("alice@acme.com"))]),
        signature: "sig-123".into(),
    };
    let bob_context = ProtectedContext {
        properties: HashMap::from([(String::from("email"), String::from("bob@acme.com"))]),
        signature: "sig-456".into(),
    };

    let protected_scope = runtime
        .block_on(public_scope.bind_protected_context(alice_context.clone()))
        .unwrap();
    // Same context a second time: the mock only serves three requests in total, so
    // this bind must not trigger another fetch.
    let cached_scope = runtime
        .block_on(
            client
                .create_scope(Properties::new())
                .bind_protected_context(alice_context),
        )
        .unwrap();
    let other_scope = runtime
        .block_on(
            client
                .create_scope(Properties::new())
                .bind_protected_context(bob_context),
        )
        .unwrap();

    assert!(!public_scope.is_open("private.dashboard"));
    assert!(protected_scope.is_open("private.dashboard"));
    assert!(cached_scope.is_open("private.dashboard"));
    assert!(other_scope.is_open("private.dashboard"));

    let open_after_clear = runtime
        .block_on(protected_scope.clear_protected_context())
        .unwrap()
        .is_open("private.dashboard");
    assert!(!open_after_clear);

    // One public bundle plus one per distinct protected context.
    assert_eq!(client.known_bundle_count_for_testing(), 3);

    drop(client);
    server.join().unwrap();

    let seen = request_bodies.lock().unwrap();
    assert_eq!(seen.len(), 3);
    assert!(seen[0].starts_with("POST /api/v1/guards"));
    assert!(seen[1].contains("alice@acme.com"));
    assert!(seen[2].contains("bob@acme.com"));
}
/// With several cached bundles, the effective refresh rate is the smallest interval any
/// of them was served with.
#[test]
fn refreshes_all_cached_bundles_using_shortest_interval() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();

    // Mock backend: the first fetch is answered with a 90s interval, the second with 45s.
    let server = thread::spawn(move || {
        let scripted_bodies = [
            "{\"guards\":[],\"refreshRateSeconds\":90,\"etag\":\"public\"}",
            "{\"guards\":[],\"refreshRateSeconds\":45,\"etag\":\"protected\"}",
        ];
        for body in scripted_bodies {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = read_request(&mut stream);
            write_response(&mut stream, 200, body);
        }
    });

    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        backend_url: format!("http://{address}"),
        refresh_rate_seconds: 30,
        flush_rate_seconds: 3600,
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();

    // Only the public bundle is known so far.
    assert_eq!(client.current_refresh_rate_for_testing(), 90);

    let context = ProtectedContext {
        properties: HashMap::from([(
            String::from("email"),
            String::from("alice@acme.com"),
        )]),
        signature: "sig-123".into(),
    };
    let _ = runtime
        .block_on(
            client
                .create_scope(Properties::new())
                .bind_protected_context(context),
        )
        .unwrap();

    // The protected bundle's shorter interval now wins.
    assert_eq!(client.current_refresh_rate_for_testing(), 45);

    drop(client);
    server.join().unwrap();
}
/// A refresh interval returned by the backend is honoured even when it is shorter than
/// the one configured on the client.
#[test]
fn server_can_speed_up_refreshes_below_client_default() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();

    // Mock backend: a single guards response announcing a 5 second interval.
    let server = thread::spawn(move || {
        let (mut stream, _) = listener.accept().unwrap();
        let _ = read_request(&mut stream);
        let body = "{\"guards\":[],\"refreshRateSeconds\":5,\"etag\":\"public\"}";
        write_response(&mut stream, 200, body);
    });

    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        backend_url: format!("http://{address}"),
        refresh_rate_seconds: 30,
        flush_rate_seconds: 3600,
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();

    assert_eq!(client.current_refresh_rate_for_testing(), 5);

    drop(client);
    server.join().unwrap();
}
/// For a guard name the client has no definition for, `is_open` returns the client's
/// configured `fallback` value.
#[test]
fn fallback_is_used_before_initial_load() {
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    // Both clients are identical apart from the `fallback` flag.
    let init_with_fallback = |fallback: bool| {
        runtime
            .block_on(Client::init(
                "test-key",
                ClientOptions {
                    backend_url: "http://localhost:9999".into(),
                    fallback,
                    ..Default::default()
                },
            ))
            .unwrap()
    };

    let fallback_true = init_with_fallback(true);
    assert!(fallback_true.is_open("unknown"));

    let fallback_false = init_with_fallback(false);
    assert!(!fallback_false.is_open("unknown"));
}
/// Checking an unadopted guard repeatedly queues its name once, and a flush sends it to
/// the unadopted-guards endpoint and clears the queue.
#[test]
fn unadopted_guard_is_reported_once() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();

    // Mock backend: first the initial guards fetch (200), then the unadopted-guards
    // report (204), whose content is verified inside the server thread.
    let server = thread::spawn(move || {
        for status in [200, 204] {
            let (mut stream, _) = listener.accept().unwrap();
            let request = read_request(&mut stream);
            if status == 200 {
                write_response(
                    &mut stream,
                    status,
                    "{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag\"}",
                );
                continue;
            }
            assert!(request.starts_with("POST /api/v1/unadopted-guards"));
            assert_eq!(request.matches("\"g\"").count(), 1);
            write_response(&mut stream, status, "");
        }
    });

    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        backend_url: format!("http://{address}"),
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();

    let unadopted = Guard {
        name: "g".into(),
        rules: vec![],
        default_value: false,
        adopted: false,
        rate_limit_per_minute: 0,
        rate_limit_properties: vec![],
        disable_measurement: None,
    };
    client.set_guards_for_testing(vec![unadopted]);

    // The guard reports open on both checks despite `default_value: false`.
    assert!(client.is_open("g"));
    assert!(client.is_open("g"));

    let pending = client.pending_unadopted_guards_for_testing();
    assert_eq!(pending, vec!["g".to_string()]);

    runtime.block_on(client.flush()).unwrap();
    assert!(client.pending_unadopted_guards_for_testing().is_empty());

    drop(client);
    server.join().unwrap();
}
/// A plain guard check buffers one signal carrying identifiers, a sequence number, a
/// callsite id and the `guard_check` kind.
#[test]
fn signal_includes_tracing_metadata() {
    let (_runtime, client) = make_client();
    let always_open = adopted_guard("g", true, vec![], 0, vec![]);
    client.set_guards_for_testing(vec![always_open]);

    assert!(client.is_open("g"));

    let buffer = client.inner.signal_buffer.lock().unwrap();
    let signal = &buffer[0];

    // Identifiers must be populated.
    assert!(!signal.signal_id.is_empty());
    assert!(!signal.execution_id.is_empty());
    assert!(!signal.callsite_id.is_empty());

    // First (and only) signal of its execution.
    assert_eq!(signal.sequence_number, 1);
    assert_eq!(signal.kind, "guard_check");

    // NOTE(review): `execute_if_open_emits_execution_telemetry` expects a measurement on
    // execution signals with the same default client; confirm that plain check signals
    // are meant to carry none by default.
    assert!(signal.measurement.is_none());
}
/// Guard checks made inside one `with_execution` closure share an execution id, and the
/// second signal names the first as its parent.
#[test]
fn with_execution_links_signals() {
    let (_runtime, client) = make_client();
    let guards = vec![
        adopted_guard("a", true, vec![], 0, vec![]),
        adopted_guard("b", true, vec![], 0, vec![]),
    ];
    client.set_guards_for_testing(guards);

    client.with_execution(|| {
        assert!(client.is_open("a"));
        assert!(client.is_open("b"));
    });

    let buffer = client.inner.signal_buffer.lock().unwrap();
    assert_eq!(buffer.len(), 2);

    let (first, second) = (&buffer[0], &buffer[1]);
    assert_eq!(first.execution_id, second.execution_id);
    assert_eq!(
        second.parent_signal_id.as_deref(),
        Some(first.signal_id.as_str())
    );
}
/// `execute_if_open` runs the closure for an open guard and buffers a second,
/// `guard_execution` signal that is parented to the check and marked completed.
#[test]
fn execute_if_open_emits_execution_telemetry() {
    let (_runtime, client) = make_client();
    let always_open = adopted_guard("g", true, vec![], 0, vec![]);
    client.set_guards_for_testing(vec![always_open]);

    let outcome = client.execute_if_open("g", || 42);
    assert_eq!(outcome, Some(42));

    let buffer = client.inner.signal_buffer.lock().unwrap();
    assert_eq!(buffer.len(), 2);

    let (check_signal, execution_signal) = (&buffer[0], &buffer[1]);
    assert_eq!(execution_signal.kind, "guard_execution");
    assert_eq!(
        execution_signal.parent_signal_id.as_deref(),
        Some(check_signal.signal_id.as_str())
    );

    let execution = execution_signal
        .measurement
        .as_ref()
        .unwrap()
        .guard_execution
        .as_ref()
        .unwrap();
    assert!(execution.completed);
    assert!(execution.error_class.is_none());
}
/// When the closure passed to `try_execute_if_open` returns `Err`, the error is
/// propagated and the execution signal records a non-completed run with the error's
/// type name.
#[test]
fn try_execute_if_open_records_failed_execution_results() {
    let (_runtime, client) = make_client();
    let always_open = adopted_guard("g", true, vec![], 0, vec![]);
    client.set_guards_for_testing(vec![always_open]);

    let outcome = client.try_execute_if_open("g", || -> Result<i32, std::io::Error> {
        Err(std::io::Error::other("boom"))
    });
    assert!(outcome.is_err());

    let buffer = client.inner.signal_buffer.lock().unwrap();
    assert_eq!(buffer.len(), 2);

    let execution = buffer[1]
        .measurement
        .as_ref()
        .unwrap()
        .guard_execution
        .as_ref()
        .unwrap();
    assert!(!execution.completed);
    assert_eq!(
        execution.error_class.as_deref(),
        Some("std::io::error::Error")
    );
}
/// A panic inside `execute_if_open` must not leak execution state: the next guard check
/// on the same thread starts a fresh execution.
#[test]
fn execution_context_clears_after_panics() {
    let (_runtime, client) = make_client();
    let always_open = adopted_guard("g", true, vec![], 0, vec![]);
    client.set_guards_for_testing(vec![always_open]);

    let unwound = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        client.execute_if_open("g", || -> i32 { panic!("boom") });
    }));
    assert!(unwound.is_err());

    // Follow-up check after the panic has been caught.
    assert!(client.is_open("g"));

    let buffer = client.inner.signal_buffer.lock().unwrap();
    assert_eq!(buffer.len(), 3);

    let (before_panic, after_panic) = (&buffer[0], &buffer[2]);
    assert_ne!(after_panic.execution_id, before_panic.execution_id);
    assert!(after_panic.parent_signal_id.is_none());
    assert_eq!(after_panic.sequence_number, 1);
}
/// A guard check made while an OpenTelemetry span is attached records that span's trace
/// and span ids on the signal.
#[test]
fn signals_capture_active_trace_context() {
    let (_runtime, client) = make_client();
    let always_open = adopted_guard("g", true, vec![], 0, vec![]);
    client.set_guards_for_testing(vec![always_open]);

    let provider = SdkTracerProvider::builder().build();
    let tracer = provider.tracer("liteguard-test");
    let span = tracer.start("guard-check");
    // Capture the ids before the span is moved into the context.
    let expected_trace_id = span.span_context().trace_id().to_string();
    let expected_span_id = span.span_context().span_id().to_string();

    let span_context = opentelemetry::Context::current_with_span(span);
    let attached = span_context.attach();
    assert!(client.is_open("g"));
    drop(attached);
    let _ = provider.shutdown();

    let buffer = client.inner.signal_buffer.lock().unwrap();
    let captured = buffer[0]
        .trace
        .as_ref()
        .expect("trace context should be captured");
    assert_eq!(captured.trace_id, expected_trace_id);
    assert_eq!(captured.span_id, expected_span_id);
    assert!(captured.parent_span_id.is_empty());
}
/// Signals carry no measurement when measurement is disabled at any of three levels:
/// the client options, the per-call options, or the guard definition.
#[test]
fn measurement_is_omitted_when_disabled_by_client_check_or_guard() {
    // 1. Disabled through the client options.
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let client_options = ClientOptions {
        backend_url: "http://localhost:9999".into(),
        disable_measurement: true,
        ..Default::default()
    };
    let client_level = runtime
        .block_on(Client::init("test-key", client_options))
        .unwrap();
    client_level.set_guards_for_testing(vec![adopted_guard("g", true, vec![], 0, vec![])]);
    assert!(client_level.is_open("g"));
    let client_level_omitted = client_level.inner.signal_buffer.lock().unwrap()[0]
        .measurement
        .is_none();
    assert!(client_level_omitted);

    // 2. Disabled through the options of a single call.
    let (_call_runtime, call_level) = make_client();
    call_level.set_guards_for_testing(vec![adopted_guard("g", true, vec![], 0, vec![])]);
    let call_options = Options {
        disable_measurement: true,
        ..Default::default()
    };
    assert!(call_level.is_open_with_options("g", &call_options));
    let call_level_omitted = call_level.inner.signal_buffer.lock().unwrap()[0]
        .measurement
        .is_none();
    assert!(call_level_omitted);

    // 3. Disabled on the guard itself; covers both the check and the execution signal.
    let (_guard_runtime, guard_level) = make_client();
    let unmeasured_guard = Guard {
        name: "g".into(),
        rules: vec![],
        default_value: true,
        adopted: true,
        rate_limit_per_minute: 0,
        rate_limit_properties: vec![],
        disable_measurement: Some(true),
    };
    guard_level.set_guards_for_testing(vec![unmeasured_guard]);
    assert_eq!(guard_level.execute_if_open("g", || 1), Some(1));
    let buffer = guard_level.inner.signal_buffer.lock().unwrap();
    assert!(buffer[0].measurement.is_none());
    assert!(buffer[1].measurement.is_none());
}
/// With `flush_size: 2` and `flush_buffer_multiplier: 2` the buffer settles at four
/// signals; adding another evicts the oldest one and records the drop on the newest.
#[test]
fn drops_oldest_signals_when_buffer_is_full() {
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        backend_url: "http://localhost:9999".into(),
        flush_size: 2,
        flush_buffer_multiplier: 2,
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();
    client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], 0, vec![])]);

    // Five checks: one more than the four the buffer is expected to settle at.
    for _ in 0..5 {
        assert!(client.is_open("g"));
    }

    // Poll (bounded by a 250ms timeout) until the buffer holds exactly four signals.
    runtime.block_on(async {
        tokio::time::timeout(Duration::from_millis(250), async {
            loop {
                let buffered = client.inner.signal_buffer.lock().unwrap().len();
                if buffered == 4 {
                    break;
                }
                tokio::task::yield_now().await;
            }
        })
        .await
        .unwrap();
    });

    let oldest_signal_id = client.inner.signal_buffer.lock().unwrap()[0]
        .signal_id
        .clone();

    // One more check must push the remembered oldest signal out.
    assert!(client.is_open("g"));

    runtime.block_on(async {
        tokio::time::timeout(Duration::from_millis(250), async {
            loop {
                let settled = {
                    let buffer = client.inner.signal_buffer.lock().unwrap();
                    let at_capacity = buffer.len() == 4;
                    let oldest_evicted = buffer
                        .iter()
                        .all(|signal| signal.signal_id != oldest_signal_id);
                    let drop_recorded = buffer
                        .last()
                        .map(|signal| signal.dropped_signals_since_last >= 1)
                        .unwrap_or(false);
                    at_capacity && oldest_evicted && drop_recorded
                };
                if settled {
                    break;
                }
                tokio::task::yield_now().await;
            }
        })
        .await
        .unwrap();
    });
}
/// A rate limit keyed on `region` is tracked per region value: exhausting it for one
/// region leaves another region's quota untouched.
#[test]
fn rate_limit_applies_property_scoped_buckets_across_scopes() {
    let (_runtime, client) = make_client();
    // One open evaluation per minute, bucketed by the `region` property.
    let limited = adopted_guard("rl", true, vec![], 1, vec!["region"]);
    client.set_guards_for_testing(vec![limited]);

    let us_scope = client.create_scope(Properties::new().set("region", "us-east"));
    let eu_scope = client.create_scope(Properties::new().set("region", "eu-west"));

    let us_first = us_scope.is_open("rl");
    let us_second = us_scope.is_open("rl");
    let us_peek = us_scope.peek_is_open("rl");
    let eu_first = eu_scope.is_open("rl");

    assert!(us_first);
    assert!(!us_second);
    assert!(!us_peek);
    assert!(eu_first);
}
/// Evaluations that come out closed by rule must not use up the guard's rate-limit
/// quota; only open evaluations count against it.
#[test]
fn closed_evaluations_do_not_consume_quota() {
    let (_runtime, client) = make_client();
    let pro_only = rule(
        "plan",
        Operator::Equals,
        vec![PropertyValue::Text("pro".into())],
        true,
    );
    // Closed by default, open for `plan == "pro"`, limited to one per minute.
    client.set_guards_for_testing(vec![adopted_guard("rl", false, vec![pro_only], 1, vec![])]);

    // Two closed evaluations first.
    let free_scope = client.create_scope(Properties::new().set("plan", "free"));
    let free_first = free_scope.is_open("rl");
    let free_second = free_scope.is_open("rl");
    assert!(!free_first);
    assert!(!free_second);

    // The single allowed open evaluation is still available afterwards, then the
    // limit kicks in.
    let pro_scope = client.create_scope(Properties::new().set("plan", "pro"));
    let pro_first = pro_scope.is_open("rl");
    assert!(pro_first);
    let pro_second = pro_scope.is_open("rl");
    assert!(!pro_second);
}
/// With a backend URL that ends in a slash, every request still targets a clean
/// `/api/v1/...` path, signal and unadopted-guard reports carry the configured
/// environment, and a protected-context bind sends the context with its signature.
#[test]
fn environment_and_protected_context_are_sent_with_normalized_paths() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();
    let requests = Arc::new(Mutex::new(Vec::<String>::new()));
    let captured = Arc::clone(&requests);

    // Mock backend: four scripted responses in order. An empty body is answered with
    // 204, anything else with 200.
    let server = thread::spawn(move || {
        let scripted_bodies = [
            "{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag-1\"}",
            "",
            "",
            "{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag-2\"}",
        ];
        for body in scripted_bodies {
            let (mut stream, _) = listener.accept().unwrap();
            let request = read_request(&mut stream);
            captured.lock().unwrap().push(request);
            let status = if body.is_empty() { 204 } else { 200 };
            write_response(&mut stream, status, body);
        }
    });

    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        // Note the trailing slash.
        backend_url: format!("http://{address}/"),
        environment: Some("production".into()),
        refresh_rate_seconds: 3600,
        flush_rate_seconds: 3600,
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();

    let unadopted = Guard {
        name: "unadopted".into(),
        rules: vec![],
        default_value: false,
        adopted: false,
        rate_limit_per_minute: 0,
        rate_limit_properties: vec![],
        disable_measurement: None,
    };
    client.set_guards_for_testing(vec![
        adopted_guard("open", true, vec![], 0, vec![]),
        unadopted,
    ]);
    assert!(client.is_open("open"));
    assert!(client.is_open("unadopted"));
    runtime.block_on(client.flush()).unwrap();

    let context = ProtectedContext {
        properties: HashMap::from([(
            String::from("email"),
            String::from("alice@acme.com"),
        )]),
        signature: "sig-123".into(),
    };
    let scope = runtime
        .block_on(
            client
                .create_scope(Properties::new())
                .bind_protected_context(context),
        )
        .unwrap();
    assert!(scope.protected_context().is_some());
    runtime.block_on(scope.clear_protected_context()).unwrap();

    drop(client);
    server.join().unwrap();

    let seen = requests.lock().unwrap();
    // Initial guards fetch.
    assert!(seen[0].starts_with("POST /api/v1/guards"));
    // Signal flush.
    assert!(seen[1].starts_with("POST /api/v1/signals"));
    assert!(seen[1].contains("\"environment\":\"production\""));
    assert!(seen[1].contains("\"measurement\""));
    // Unadopted-guard report.
    assert!(seen[2].starts_with("POST /api/v1/unadopted-guards"));
    assert!(seen[2].contains("\"environment\":\"production\""));
    // Guards fetch for the protected context.
    assert!(seen[3].starts_with("POST /api/v1/guards"));
    assert!(seen[3].contains("\"protectedContext\""));
    assert!(seen[3].contains("\"signature\":\"sig-123\""));
}
/// A flush that is answered with an HTTP error status fails and keeps the signal
/// buffered; a later successful flush empties the buffer.
#[test]
fn flush_requeues_on_http_status_error() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();

    // Mock backend: guards fetch (200), failing flush (500), succeeding flush (204).
    let server = thread::spawn(move || {
        for status in [200, 500, 204] {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = read_request(&mut stream);
            let body = match status {
                200 => "{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag\"}",
                _ => "",
            };
            write_response(&mut stream, status, body);
        }
    });

    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        backend_url: format!("http://{address}"),
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();
    client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], 0, vec![])]);
    assert!(client.is_open("g"));

    let first_flush = runtime.block_on(client.flush());
    assert!(first_flush.is_err());
    assert_eq!(client.inner.signal_buffer.lock().unwrap().len(), 1);

    let second_flush = runtime.block_on(client.flush());
    assert!(second_flush.is_ok());
    assert_eq!(client.inner.signal_buffer.lock().unwrap().len(), 0);

    drop(client);
    server.join().unwrap();
}
}