use crate::decision::Execution;
use crate::evaluation::{evaluate_guard, evaluate_guard_detailed};
use crate::types::{
ClientOptions, Guard, GuardDecision, GuardDecisionReason, GuardExecutionPerformance, InitError,
Options, Properties, ProtectedContext, RateLimitDecision, RateLimitDecisionOutcome,
RateLimitEvaluationTarget, SendUnadoptedGuardsRequest, Signal, SignalPerformance, TraceContext,
UnadoptedGuardObservation,
};
use opentelemetry::trace::TraceContextExt;
use reqwest::Client as HttpClient;
use serde::Deserialize;
use serde_json::{Map, Value};
use std::cell::RefCell;
use std::collections::HashMap;
use std::panic::Location;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::runtime::Handle;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use tokio::time;
thread_local! {
// Per-thread execution context used to correlate signals emitted on this thread.
// `None` means no execution is active; set by `with_execution_inner` and
// `start_execution_inner`.
static EXECUTION_STATE: RefCell<Option<ExecutionState>> = const { RefCell::new(None) };
}
// NOTE(review): not referenced in the visible code; presumably backs `next_signal_id`
// (defined elsewhere) — confirm.
static SIGNAL_COUNTER: AtomicU64 = AtomicU64::new(0);
// Map key of the bundle used when no protected context is bound (the "public" bundle).
const PUBLIC_BUNDLE_KEY: &str = "";
/// Execution context stored per thread in `EXECUTION_STATE`.
#[derive(Clone)]
struct ExecutionState {
/// Id produced by `next_signal_id()` when the context was created.
execution_id: String,
/// Starts at 0; presumably advanced per emitted signal by `next_signal_metadata`
/// (outside this view) — confirm.
sequence_number: i64,
/// Starts as `None`; presumably tracks the most recent signal id — confirm.
last_signal_id: Option<String>,
}
/// Correlation fields copied onto a `Signal` by `buffer_signal_inner`.
struct SignalMetadata {
signal_id: String,
execution_id: String,
parent_signal_id: Option<String>,
sequence_number: i64,
}
/// Result of `Client::evaluate_guard_in_scope`.
struct GuardEvaluation {
/// Final open/closed answer, after rate limiting.
result: bool,
/// The guard definition when the bundle knows the guard (adopted or not).
guard: Option<Guard>,
/// Merged scope + call properties; only set for adopted guards.
props: Option<Properties>,
/// The buffered `guard_check` signal, when one was emitted.
signal: Option<Signal>,
}
/// Arguments for `Client::buffer_signal_inner`.
struct BufferedSignalInput {
guard_name: String,
result: bool,
props: Properties,
metadata: SignalMetadata,
/// `"file:line"` of the call site.
callsite_id: String,
/// Signal kind; this file uses `"guard_check"` and `"guard_execution"`.
kind: String,
measurement: Option<SignalPerformance>,
rate_limit_decisions: Vec<RateLimitDecision>,
}
/// Outcome of a single rate-limit bucket check.
struct RateLimitEvaluation {
/// `true` when the bucket had quota left for this request.
allowed: bool,
/// Requests in the current window counting this one (`count + 1`).
count_in_window: i32,
}
/// Guard result after applying active and dry-run rate limits.
struct AppliedRateLimit {
result: bool,
/// Decisions to attach to the emitted signal; empty when no signal is emitted.
rate_limit_decisions: Vec<RateLimitDecision>,
}
/// Cached guard definitions for one bundle (public or a specific protected context).
#[derive(Clone)]
struct GuardBundle {
/// Key in `Inner::bundles`; `PUBLIC_BUNDLE_KEY` for the public bundle.
key: String,
/// Guards indexed by `Guard::name`.
guards: HashMap<String, Guard>,
/// `false` until a successful fetch (or `set_guards_for_testing`) populates the bundle;
/// while `false`, checks return the fallback value.
ready: bool,
/// ETag from the last successful fetch, sent back as `If-None-Match`.
etag: String,
/// Context sent with the fetch request; `None` for the public bundle.
protected_context: Option<ProtectedContext>,
/// Effective refresh interval for this bundle, in seconds.
refresh_rate_seconds: u64,
}
/// JSON body returned by `POST {backend_url}/api/v1/guards`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GuardsResponse {
guards: Vec<Guard>,
/// `0` when absent; treated as "no server preference".
#[serde(default)]
refresh_rate_seconds: u64,
/// Empty when absent; an empty etag is never sent as `If-None-Match`.
#[serde(default)]
etag: String,
}
/// Fixed one-minute window counter for a single rate-limit bucket.
#[derive(Clone)]
struct RateLimitEntry {
window_start: Instant,
count: i32,
}
/// Aggregated checks of a guard that is missing from, or unadopted in, its bundle.
/// Timestamps are milliseconds since the Unix epoch.
#[derive(Clone)]
struct PendingUnadoptedGuardObservation {
first_seen_ms: i64,
last_seen_ms: i64,
check_count: i64,
}
/// State shared (via `Arc`) by a `Client`, its `Scope`s and its background tasks.
struct Inner {
project_client_token: String,
/// Options after `ClientOptions::with_defaults`.
opts: ClientOptions,
/// Runtime handle captured at `init`, used to spawn flushes from synchronous code.
runtime: Handle,
http: HttpClient,
/// Guard bundles keyed by `PUBLIC_BUNDLE_KEY` or a protected-context cache key.
bundles: RwLock<HashMap<String, GuardBundle>>,
/// Smallest positive per-bundle refresh rate; read by the refresh task each cycle.
current_refresh_rate_seconds: Mutex<u64>,
/// Wakes the refresh task early when `current_refresh_rate_seconds` changes.
refresh_notify: Notify,
/// Signals waiting to be flushed; bounded by `flush_size * flush_buffer_multiplier`.
signal_buffer: Mutex<Vec<Signal>>,
/// Signals evicted from a full buffer since the last signal was created.
dropped_signals_pending: Mutex<i32>,
/// Observations of unadopted/unknown guards, keyed by guard name.
pending_unadopted_guards: Mutex<HashMap<String, PendingUnadoptedGuardObservation>>,
/// Rate-limit windows keyed by slot key plus partition property values.
rate_limit_state: Mutex<HashMap<String, RateLimitEntry>>,
}
/// Extrapolates an observed number of checks to a per-minute rate.
///
/// The observation window is `last_seen_ms - first_seen_ms`, clamped to at least one
/// second so that a burst seen within a single instant (or an inverted window) does not
/// yield an inflated rate. A non-positive `check_count` yields `0.0`.
fn estimate_checks_per_minute(first_seen_ms: i64, last_seen_ms: i64, check_count: i64) -> f64 {
    const MS_PER_MINUTE: f64 = 60_000.0;
    const MIN_WINDOW_MS: i64 = 1000;
    match check_count {
        count if count <= 0 => 0.0,
        count => {
            let window = std::cmp::max(last_seen_ms - first_seen_ms, MIN_WINDOW_MS);
            count as f64 * MS_PER_MINUTE / window as f64
        }
    }
}
impl Inner {
    /// Prints a `[liteguard]`-prefixed diagnostic line to stderr unless `opts.quiet` is set.
    fn log(&self, message: impl AsRef<str>) {
        if !self.opts.quiet {
            let text = message.as_ref();
            eprintln!("[liteguard] {}", text);
        }
    }
}
/// Evaluation context: a set of properties bound to a client and to one guard bundle.
///
/// Client state is shared through `Arc`; the property map is copied on clone.
#[derive(Clone)]
pub struct Scope {
inner: Arc<Inner>,
/// Properties merged into every evaluation made through this scope.
properties: Properties,
/// Key into `Inner::bundles`; `PUBLIC_BUNDLE_KEY` unless a protected context is bound.
bundle_key: String,
protected_context: Option<ProtectedContext>,
}
impl Scope {
    /// Returns the properties attached to this scope.
    pub fn properties(&self) -> &Properties {
        &self.properties
    }

    /// Returns the protected context bound to this scope, if any.
    pub fn protected_context(&self) -> Option<&ProtectedContext> {
        self.protected_context.as_ref()
    }

    /// Returns a copy of this scope (same client, bundle and protected context) that
    /// carries `properties` in place of this scope's properties.
    fn with_replaced_properties(&self, properties: Properties) -> Scope {
        Scope {
            inner: Arc::clone(&self.inner),
            properties,
            bundle_key: self.bundle_key.clone(),
            protected_context: self.protected_context.clone(),
        }
    }

    /// Returns a new scope whose properties are this scope's properties overlaid with
    /// `properties`; on key collisions the value from `properties` wins.
    pub fn with_properties(&self, properties: Properties) -> Scope {
        let mut merged = self.properties.clone();
        for (key, value) in properties.0 {
            merged.insert(key, value);
        }
        self.with_replaced_properties(merged)
    }

    /// Alias for [`Scope::with_properties`].
    pub fn add_properties(&self, properties: Properties) -> Scope {
        self.with_properties(properties)
    }

    /// Returns a new scope with the properties named in `names` removed.
    pub fn clear_properties(&self, names: &[&str]) -> Scope {
        let mut cleared = self.properties.clone();
        for name in names {
            cleared.remove(name);
        }
        self.with_replaced_properties(cleared)
    }

    /// Returns a new scope with no properties, keeping the bundle and protected context.
    pub fn reset_properties(&self) -> Scope {
        self.with_replaced_properties(Properties::new())
    }

    /// Returns `true` if this scope was created from `client` (same shared state).
    pub fn belongs_to(&self, client: &Client) -> bool {
        Arc::ptr_eq(&self.inner, &client.inner)
    }

    /// Returns a scope that evaluates guards against the bundle for `protected_context`,
    /// fetching that bundle first if it is not ready yet.
    ///
    /// # Errors
    /// Returns the transport or body-decoding error from the guard fetch. A non-success
    /// HTTP status is logged by the fetch and does not produce an error.
    pub async fn bind_protected_context(
        &self,
        protected_context: ProtectedContext,
    ) -> Result<Scope, reqwest::Error> {
        let bundle_key =
            Client::ensure_bundle_for_protected_context(&self.inner, protected_context.clone())
                .await?;
        Ok(Scope {
            inner: Arc::clone(&self.inner),
            properties: self.properties.clone(),
            bundle_key,
            protected_context: Some(protected_context),
        })
    }

    /// Returns a scope bound to the public bundle, fetching it first if it is not ready.
    ///
    /// # Errors
    /// Returns the transport or body-decoding error from the public bundle fetch.
    pub async fn clear_protected_context(&self) -> Result<Scope, reqwest::Error> {
        Client::ensure_public_bundle_ready(&self.inner).await?;
        Ok(Scope {
            inner: Arc::clone(&self.inner),
            properties: self.properties.clone(),
            bundle_key: PUBLIC_BUNDLE_KEY.to_owned(),
            protected_context: None,
        })
    }

    /// Reports whether guard `name` is open in this scope, recording the check.
    #[track_caller]
    pub fn is_open(&self, name: &str) -> bool {
        self.is_open_with_options(name, &Options::default())
    }

    /// Like [`Scope::is_open`] with per-call `options` (properties, fallback, ...).
    #[track_caller]
    pub fn is_open_with_options(&self, name: &str, options: &Options) -> bool {
        Client::evaluate_guard_in_scope(&self.inner, self, name, options, true, Location::caller())
            .result
    }

    /// Reports whether guard `name` is open without emitting a signal or consuming
    /// active rate-limit quota.
    #[track_caller]
    pub fn peek_is_open(&self, name: &str) -> bool {
        self.peek_is_open_with_options(name, &Options::default())
    }

    /// Like [`Scope::peek_is_open`] with per-call `options`.
    #[track_caller]
    pub fn peek_is_open_with_options(&self, name: &str, options: &Options) -> bool {
        Client::evaluate_guard_in_scope(&self.inner, self, name, options, false, Location::caller())
            .result
    }

    /// Evaluates guard `name` and returns the full decision (reason, matched rule, ...).
    #[track_caller]
    pub fn evaluate(&self, name: &str) -> GuardDecision {
        self.evaluate_with_options(name, &Options::default())
    }

    /// Like [`Scope::evaluate`] with per-call `options`.
    #[track_caller]
    pub fn evaluate_with_options(&self, name: &str, options: &Options) -> GuardDecision {
        Client::evaluate_guard_decision(&self.inner, self, name, options, Location::caller())
    }

    /// Starts a new execution context on the current thread.
    pub fn start_execution(&self) -> Execution {
        Client::start_execution_inner(&self.inner)
    }

    /// Runs `f` only if guard `name` is open; returns `None` when it is closed.
    #[track_caller]
    pub fn execute_if_open<F, R>(&self, name: &str, f: F) -> Option<R>
    where
        F: FnOnce() -> R,
    {
        self.execute_if_open_with_options(name, &Options::default(), f)
    }

    /// Like [`Scope::execute_if_open`] with per-call `options`.
    ///
    /// When the guard check emitted a signal, a `guard_execution` signal parented to it
    /// is buffered after `f` finishes, including when `f` panics. The panic is then
    /// resumed unchanged.
    #[track_caller]
    pub fn execute_if_open_with_options<F, R>(
        &self,
        name: &str,
        options: &Options,
        f: F,
    ) -> Option<R>
    where
        F: FnOnce() -> R,
    {
        // `#[track_caller]` does not propagate into closures. Capture the caller here so
        // signals record the user's call site, not a line inside this file.
        let caller = Location::caller();
        Client::with_execution_inner(&self.inner, || {
            let evaluation =
                Client::evaluate_guard_in_scope(&self.inner, self, name, options, true, caller);
            if !evaluation.result {
                return None;
            }
            let execution = match self.begin_guard_execution(evaluation, name, options, caller) {
                Some(execution) => execution,
                None => return Some(f()),
            };
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
                Ok(value) => {
                    execution.finish(&self.inner, true, || None);
                    Some(value)
                }
                Err(payload) => {
                    execution.finish(&self.inner, false, || {
                        Some(Client::panic_error_class(payload.as_ref()))
                    });
                    std::panic::resume_unwind(payload)
                }
            }
        })
    }

    /// Runs the fallible `f` only if guard `name` is open; returns `Ok(None)` when closed.
    #[track_caller]
    pub fn try_execute_if_open<F, R, E>(&self, name: &str, f: F) -> Result<Option<R>, E>
    where
        F: FnOnce() -> Result<R, E>,
    {
        self.try_execute_if_open_with_options(name, &Options::default(), f)
    }

    /// Like [`Scope::try_execute_if_open`] with per-call `options`.
    ///
    /// An `Err` from `f` is recorded as a failed execution (classified by the type name of
    /// `E`) and returned unchanged. Panics are recorded and resumed.
    #[track_caller]
    pub fn try_execute_if_open_with_options<F, R, E>(
        &self,
        name: &str,
        options: &Options,
        f: F,
    ) -> Result<Option<R>, E>
    where
        F: FnOnce() -> Result<R, E>,
    {
        // Captured outside the closure for the same reason as in `execute_if_open_with_options`.
        let caller = Location::caller();
        Client::with_execution_inner(&self.inner, || {
            let evaluation =
                Client::evaluate_guard_in_scope(&self.inner, self, name, options, true, caller);
            if !evaluation.result {
                return Ok(None);
            }
            let execution = match self.begin_guard_execution(evaluation, name, options, caller) {
                Some(execution) => execution,
                None => return f().map(Some),
            };
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
                Ok(Ok(value)) => {
                    execution.finish(&self.inner, true, || None);
                    Ok(Some(value))
                }
                Ok(Err(error)) => {
                    execution.finish(&self.inner, false, || {
                        Some(std::any::type_name::<E>().to_owned())
                    });
                    Err(error)
                }
                Err(payload) => {
                    execution.finish(&self.inner, false, || {
                        Some(Client::panic_error_class(payload.as_ref()))
                    });
                    std::panic::resume_unwind(payload)
                }
            }
        })
    }

    /// Prepares `guard_execution` bookkeeping for a guard check that emitted a signal.
    ///
    /// Returns `None` when the check emitted no signal; the guarded closure then runs
    /// without execution tracking.
    fn begin_guard_execution(
        &self,
        evaluation: GuardEvaluation,
        name: &str,
        options: &Options,
        caller: &'static Location<'static>,
    ) -> Option<PendingGuardExecution> {
        let guard_check_signal_id = evaluation.signal.as_ref()?.signal_id.clone();
        let guard = evaluation
            .guard
            .as_ref()
            .expect("guard should exist when signal exists");
        let props = evaluation
            .props
            .expect("props should exist when signal exists");
        let measurement_enabled = Client::is_measurement_enabled(&self.inner, guard, options);
        let started_at = Instant::now();
        Some(PendingGuardExecution {
            guard_name: name.to_owned(),
            props,
            guard_check_signal_id,
            callsite_id: format!("{}:{}", caller.file(), caller.line()),
            measurement_enabled,
            started_at,
        })
    }
}

/// Data captured when a guarded closure starts, so that a `guard_execution` signal can be
/// buffered once the closure has finished.
struct PendingGuardExecution {
    guard_name: String,
    props: Properties,
    guard_check_signal_id: String,
    callsite_id: String,
    measurement_enabled: bool,
    started_at: Instant,
}

impl PendingGuardExecution {
    /// Buffers the `guard_execution` signal, parented to the guard-check signal.
    ///
    /// `completed` is `false` when the closure failed. `error_class` is only invoked when
    /// measurement is enabled.
    fn finish(
        self,
        inner: &Arc<Inner>,
        completed: bool,
        error_class: impl FnOnce() -> Option<String>,
    ) {
        let metadata = Client::next_signal_metadata(Some(self.guard_check_signal_id));
        let measurement = if self.measurement_enabled {
            Some(Client::capture_guard_execution_measurement(
                self.started_at,
                completed,
                error_class(),
            ))
        } else {
            None
        };
        Client::buffer_signal_inner(
            inner,
            BufferedSignalInput {
                guard_name: self.guard_name,
                result: true,
                props: self.props,
                metadata,
                callsite_id: self.callsite_id,
                kind: "guard_execution".into(),
                measurement,
                rate_limit_decisions: vec![],
            },
        );
    }
}
/// Liteguard client: shared state plus the background refresh and flush tasks spawned
/// by `Client::init`.
pub struct Client {
inner: Arc<Inner>,
// Loop that periodically re-fetches every known guard bundle.
refresh_task: JoinHandle<()>,
// Loop that periodically flushes buffered signals.
flush_task: JoinHandle<()>,
// NOTE(review): not read in the visible code; presumably managed by
// `stop_background_tasks` (defined elsewhere) — confirm.
is_shutdown: AtomicBool,
}
impl Client {
pub async fn init(
project_client_token: impl Into<String>,
opts: ClientOptions,
) -> Result<Client, InitError> {
let project_client_token = project_client_token.into();
let opts = opts.with_defaults();
let configured_refresh_rate_seconds = opts.refresh_rate_seconds;
let http = HttpClient::builder()
.timeout(Duration::from_secs(opts.http_timeout_seconds))
.build()
.map_err(|e| InitError(format!("failed to build HTTP client: {e}")))?;
let mut bundles = HashMap::new();
bundles.insert(
PUBLIC_BUNDLE_KEY.to_owned(),
Self::create_empty_bundle(
PUBLIC_BUNDLE_KEY.to_owned(),
None,
configured_refresh_rate_seconds,
),
);
let inner = Arc::new(Inner {
project_client_token,
opts,
runtime: Handle::current(),
http,
bundles: RwLock::new(bundles),
current_refresh_rate_seconds: Mutex::new(configured_refresh_rate_seconds),
refresh_notify: Notify::new(),
signal_buffer: Mutex::new(Vec::new()),
dropped_signals_pending: Mutex::new(0),
pending_unadopted_guards: Mutex::new(HashMap::new()),
rate_limit_state: Mutex::new(HashMap::new()),
});
let _ = Self::fetch_guards_for_bundle(&inner, PUBLIC_BUNDLE_KEY).await;
let refresh_task = {
let inner = Arc::clone(&inner);
tokio::spawn(async move {
loop {
let interval = *inner.current_refresh_rate_seconds.lock().unwrap();
if time::timeout(
Duration::from_secs(interval),
inner.refresh_notify.notified(),
)
.await
.is_ok()
{
continue;
}
let bundle_keys = {
let bundles = inner.bundles.read().unwrap();
let mut keys = bundles.keys().cloned().collect::<Vec<_>>();
keys.sort();
keys
};
for bundle_key in bundle_keys {
let _ = Client::fetch_guards_for_bundle(&inner, &bundle_key).await;
}
}
})
};
let flush_task = {
let inner = Arc::clone(&inner);
let interval = inner.opts.flush_rate_seconds;
tokio::spawn(async move {
let mut ticker = time::interval(Duration::from_secs(interval));
ticker.tick().await;
loop {
ticker.tick().await;
let _ = Client::flush_signals_inner(&inner).await;
}
})
};
Ok(Client {
inner,
refresh_task,
flush_task,
is_shutdown: AtomicBool::new(false),
})
}
/// Creates a scope over the public (unprotected) guard bundle that carries `properties`.
pub fn create_scope(&self, properties: Properties) -> Scope {
    let inner = Arc::clone(&self.inner);
    Scope {
        protected_context: None,
        bundle_key: String::from(PUBLIC_BUNDLE_KEY),
        properties,
        inner,
    }
}
/// Property-less public scope used by the convenience methods on `Client`.
fn public_scope(&self) -> Scope {
    let no_properties = Properties::new();
    self.create_scope(no_properties)
}
#[track_caller]
pub fn is_open(&self, name: &str) -> bool {
self.is_open_with_options(name, &Options::default())
}
#[track_caller]
pub fn is_open_with_options(&self, name: &str, options: &Options) -> bool {
let scope = self.public_scope();
Self::evaluate_guard_in_scope(&self.inner, &scope, name, options, true, Location::caller())
.result
}
#[track_caller]
pub fn peek_is_open(&self, name: &str) -> bool {
self.peek_is_open_with_options(name, &Options::default())
}
#[track_caller]
pub fn peek_is_open_with_options(&self, name: &str, options: &Options) -> bool {
let scope = self.public_scope();
Self::evaluate_guard_in_scope(
&self.inner,
&scope,
name,
options,
false,
Location::caller(),
)
.result
}
#[track_caller]
pub fn evaluate(&self, name: &str) -> GuardDecision {
self.evaluate_with_options(name, &Options::default())
}
#[track_caller]
pub fn evaluate_with_options(&self, name: &str, options: &Options) -> GuardDecision {
let scope = self.public_scope();
Self::evaluate_guard_decision(&self.inner, &scope, name, options, Location::caller())
}
pub fn start_execution(&self) -> Execution {
Self::start_execution_inner(&self.inner)
}
#[track_caller]
pub fn execute_if_open<F, R>(&self, name: &str, f: F) -> Option<R>
where
F: FnOnce() -> R,
{
self.execute_if_open_with_options(name, &Options::default(), f)
}
#[track_caller]
pub fn execute_if_open_with_options<F, R>(
&self,
name: &str,
options: &Options,
f: F,
) -> Option<R>
where
F: FnOnce() -> R,
{
self.public_scope()
.execute_if_open_with_options(name, options, f)
}
#[track_caller]
pub fn try_execute_if_open<F, R, E>(&self, name: &str, f: F) -> Result<Option<R>, E>
where
F: FnOnce() -> Result<R, E>,
{
self.try_execute_if_open_with_options(name, &Options::default(), f)
}
#[track_caller]
pub fn try_execute_if_open_with_options<F, R, E>(
&self,
name: &str,
options: &Options,
f: F,
) -> Result<Option<R>, E>
where
F: FnOnce() -> Result<R, E>,
{
self.public_scope()
.try_execute_if_open_with_options(name, options, f)
}
/// Runs `f` inside an execution context on the current thread.
///
/// If a context is already active it is reused as-is. Otherwise a fresh one (new id,
/// sequence 0) is installed for the duration of `f` and cleared afterwards, including
/// when `f` panics.
pub fn with_execution<F, R>(&self, f: F) -> R
where
    F: FnOnce() -> R,
{
    Self::with_execution_inner(&self.inner, f)
}
fn with_execution_inner<F, R>(_inner: &Arc<Inner>, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Clears the thread's execution state when dropped, including during unwinding.
    struct ClearOnExit;
    impl Drop for ClearOnExit {
        fn drop(&mut self) {
            EXECUTION_STATE.with(|slot| {
                slot.replace(None);
            });
        }
    }

    let already_active = EXECUTION_STATE.with(|slot| slot.borrow().is_some());
    if already_active {
        return f();
    }
    EXECUTION_STATE.with(|slot| {
        let previous = slot.replace(Some(ExecutionState {
            execution_id: next_signal_id(),
            sequence_number: 0,
            last_signal_id: None,
        }));
        debug_assert!(previous.is_none());
    });
    let _clear_on_exit = ClearOnExit;
    f()
}
/// Sends all currently buffered signals to the backend.
pub async fn flush(&self) -> Result<(), reqwest::Error> {
    let inner = &self.inner;
    Self::flush_signals_inner(inner).await
}
/// Calls `stop_background_tasks`, then flushes whatever signals remain buffered.
pub async fn shutdown(&self) -> Result<(), reqwest::Error> {
    self.stop_background_tasks();
    Self::flush_signals_inner(&self.inner).await
}
pub fn set_guards_for_testing(&self, guards: Vec<Guard>) {
{
let mut bundles = self.inner.bundles.write().unwrap();
bundles.insert(
PUBLIC_BUNDLE_KEY.to_owned(),
GuardBundle {
key: PUBLIC_BUNDLE_KEY.to_owned(),
guards: guards
.into_iter()
.map(|guard| (guard.name.clone(), guard))
.collect(),
ready: true,
etag: String::new(),
protected_context: None,
refresh_rate_seconds: self.inner.opts.refresh_rate_seconds,
},
);
*self.inner.current_refresh_rate_seconds.lock().unwrap() =
Self::recompute_refresh_interval(&bundles, self.inner.opts.refresh_rate_seconds);
}
}
pub fn known_bundle_count_for_testing(&self) -> usize {
self.inner.bundles.read().unwrap().len()
}
pub fn current_refresh_rate_for_testing(&self) -> u64 {
*self.inner.current_refresh_rate_seconds.lock().unwrap()
}
/// Test hook: snapshot of the pending unadopted-guard observations, sorted by guard name,
/// each with its extrapolated checks-per-minute rate.
pub fn pending_unadopted_guards_for_testing(&self) -> Vec<UnadoptedGuardObservation> {
    let pending = self.inner.pending_unadopted_guards.lock().unwrap();
    let mut snapshot = Vec::with_capacity(pending.len());
    for (guard_name, observation) in pending.iter() {
        let rate = estimate_checks_per_minute(
            observation.first_seen_ms,
            observation.last_seen_ms,
            observation.check_count,
        );
        snapshot.push(UnadoptedGuardObservation {
            guard_name: guard_name.clone(),
            first_seen_ms: observation.first_seen_ms,
            last_seen_ms: observation.last_seen_ms,
            check_count: observation.check_count,
            estimated_checks_per_minute: rate,
        });
    }
    drop(pending);
    snapshot.sort_by(|a, b| a.guard_name.cmp(&b.guard_name));
    snapshot
}
/// Builds an unpopulated (`ready == false`) bundle with no guards and no etag.
fn create_empty_bundle(
    key: String,
    protected_context: Option<ProtectedContext>,
    refresh_rate_seconds: u64,
) -> GuardBundle {
    GuardBundle {
        ready: false,
        guards: HashMap::new(),
        etag: String::new(),
        key,
        protected_context,
        refresh_rate_seconds,
    }
}
/// Fetches the public bundle unless it is already populated.
async fn ensure_public_bundle_ready(inner: &Arc<Inner>) -> Result<(), reqwest::Error> {
    let already_ready = matches!(
        inner.bundles.read().unwrap().get(PUBLIC_BUNDLE_KEY),
        Some(bundle) if bundle.ready
    );
    if already_ready {
        Ok(())
    } else {
        Self::fetch_guards_for_bundle(inner, PUBLIC_BUNDLE_KEY).await
    }
}
/// Registers a bundle for `protected_context` (if new) and fetches it unless it is
/// already populated. Returns the bundle's cache key.
async fn ensure_bundle_for_protected_context(
    inner: &Arc<Inner>,
    protected_context: ProtectedContext,
) -> Result<String, reqwest::Error> {
    let bundle_key = Self::protected_context_cache_key(&protected_context);
    let already_ready = {
        let mut bundles = inner.bundles.write().unwrap();
        let bundle = bundles.entry(bundle_key.clone()).or_insert_with(|| {
            Self::create_empty_bundle(
                bundle_key.clone(),
                Some(protected_context),
                inner.opts.refresh_rate_seconds,
            )
        });
        bundle.ready
    };
    if !already_ready {
        Self::fetch_guards_for_bundle(inner, &bundle_key).await?;
    }
    Ok(bundle_key)
}
/// Derives a deterministic cache key for a protected context.
///
/// The key joins, with NUL separators: the signature, a `properties` marker, one
/// `key=value` entry per property in sorted key order, an empty separator, and the
/// `issuedAt` / `expiresAt` values (empty when absent).
fn protected_context_cache_key(protected_context: &ProtectedContext) -> String {
    let properties = &protected_context.properties;
    // Sorting makes the key independent of map iteration order.
    let mut keys = properties.keys().collect::<Vec<_>>();
    keys.sort_unstable();

    let mut parts = Vec::with_capacity(keys.len() + 5);
    parts.push(protected_context.signature.clone());
    parts.push("properties".to_owned());
    parts.extend(keys.into_iter().map(|key| {
        let value = properties.get(key).cloned().unwrap_or_default();
        format!("{key}={value}")
    }));
    parts.push(String::new());
    let issued_at = protected_context
        .issued_at
        .map(|value| value.to_string())
        .unwrap_or_default();
    parts.push(format!("issuedAt={}", issued_at));
    let expires_at = protected_context
        .expires_at
        .map(|value| value.to_string())
        .unwrap_or_default();
    parts.push(format!("expiresAt={}", expires_at));
    parts.join("\0")
}
/// Returns the smallest positive bundle refresh rate, or `default` when no bundle has one.
fn recompute_refresh_interval(bundles: &HashMap<String, GuardBundle>, default: u64) -> u64 {
    let mut fastest: Option<u64> = None;
    for bundle in bundles.values() {
        let rate = bundle.refresh_rate_seconds;
        if rate == 0 {
            continue;
        }
        fastest = Some(fastest.map_or(rate, |current| current.min(rate)));
    }
    fastest.unwrap_or(default)
}
async fn fetch_guards_for_bundle(
inner: &Arc<Inner>,
bundle_key: &str,
) -> Result<(), reqwest::Error> {
let bundle = {
let bundles = inner.bundles.read().unwrap();
bundles.get(bundle_key).cloned().unwrap_or_else(|| {
Self::create_empty_bundle(
bundle_key.to_owned(),
None,
inner.opts.refresh_rate_seconds,
)
})
};
let url = format!(
"{}/api/v1/guards",
inner.opts.backend_url.trim_end_matches('/')
);
let environment = inner.opts.environment.clone().unwrap_or_default();
let mut request_body = Map::new();
request_body.insert(
"projectClientToken".to_owned(),
Value::String(inner.project_client_token.clone()),
);
request_body.insert("environment".to_owned(), Value::String(environment.clone()));
if let Some(protected_context) = bundle.protected_context.clone() {
let mut protected_context_body = Map::new();
protected_context_body.insert(
"properties".to_owned(),
serde_json::to_value(protected_context.properties)
.unwrap_or(Value::Object(Map::new())),
);
protected_context_body.insert(
"signature".to_owned(),
Value::String(protected_context.signature),
);
if let Some(issued_at) = protected_context.issued_at {
protected_context_body
.insert("issuedAt".to_owned(), Value::Number(issued_at.into()));
}
if let Some(expires_at) = protected_context.expires_at {
protected_context_body
.insert("expiresAt".to_owned(), Value::Number(expires_at.into()));
}
request_body.insert(
"protectedContext".to_owned(),
Value::Object(protected_context_body),
);
}
let mut req = inner
.http
.post(&url)
.bearer_auth(&inner.project_client_token)
.json(&Value::Object(request_body));
if !environment.is_empty() {
req = req.header("X-Liteguard-Environment", &environment);
}
if !bundle.etag.is_empty() {
req = req.header("If-None-Match", &bundle.etag);
}
let response = req.send().await?;
if response.status().as_u16() == 304 {
return Ok(());
}
if !response.status().is_success() {
inner.log(format!(
"guard fetch failed for bundle {:?}: HTTP {}",
bundle_key,
response.status().as_u16()
));
return Ok(());
}
let body: GuardsResponse = response.json().await?;
let effective_refresh_rate = if body.refresh_rate_seconds == 0 {
inner.opts.refresh_rate_seconds
} else {
inner
.opts
.refresh_rate_seconds
.max(body.refresh_rate_seconds)
};
{
let mut bundles = inner.bundles.write().unwrap();
bundles.insert(
bundle.key.clone(),
GuardBundle {
key: bundle.key,
guards: body
.guards
.into_iter()
.map(|guard| (guard.name.clone(), guard))
.collect(),
ready: true,
etag: body.etag,
protected_context: bundle.protected_context,
refresh_rate_seconds: effective_refresh_rate,
},
);
let next_refresh =
Self::recompute_refresh_interval(&bundles, inner.opts.refresh_rate_seconds);
let mut current = inner.current_refresh_rate_seconds.lock().unwrap();
if *current != next_refresh {
*current = next_refresh;
inner.refresh_notify.notify_one();
}
}
Ok(())
}
fn evaluate_guard_in_scope(
inner: &Arc<Inner>,
scope: &Scope,
name: &str,
options: &Options,
emit_signal: bool,
caller: &'static Location<'static>,
) -> GuardEvaluation {
let bundle = {
let bundles = inner.bundles.read().unwrap();
bundles
.get(&scope.bundle_key)
.cloned()
.or_else(|| bundles.get(PUBLIC_BUNDLE_KEY).cloned())
.unwrap_or_else(|| {
Self::create_empty_bundle(
PUBLIC_BUNDLE_KEY.to_owned(),
None,
inner.opts.refresh_rate_seconds,
)
})
};
if !bundle.ready {
return GuardEvaluation {
result: options.fallback.unwrap_or(inner.opts.fallback),
guard: None,
props: None,
signal: None,
};
}
let guard = match bundle.guards.get(name) {
Some(guard) => guard.clone(),
None => {
Self::record_unadopted_guard(inner, name.to_owned());
return GuardEvaluation {
result: true,
guard: None,
props: None,
signal: None,
};
}
};
if !guard.adopted {
Self::record_unadopted_guard(inner, name.to_owned());
return GuardEvaluation {
result: true,
guard: Some(guard),
props: None,
signal: None,
};
}
let mut props = scope.properties.clone();
for (key, value) in &options.properties.0 {
props.insert(key.clone(), value.clone());
}
let applied_rate_limit = Self::apply_rate_limit(
inner,
&guard,
name,
evaluate_guard(&guard, &props),
&props,
emit_signal,
);
let result = applied_rate_limit.result;
let signal = if emit_signal {
Some(Self::buffer_signal_inner(
inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result,
props: props.clone(),
metadata: Self::next_signal_metadata(None),
callsite_id: format!("{}:{}", caller.file(), caller.line()),
kind: "guard_check".into(),
measurement: if Self::is_measurement_enabled(inner, &guard, options) {
Self::capture_guard_check_measurement()
} else {
None
},
rate_limit_decisions: applied_rate_limit.rate_limit_decisions,
},
))
} else {
None
};
GuardEvaluation {
result,
guard: Some(guard),
props: Some(props),
signal,
}
}
fn evaluate_guard_decision(
inner: &Arc<Inner>,
scope: &Scope,
name: &str,
options: &Options,
caller: &'static Location<'static>,
) -> GuardDecision {
let bundle = {
let bundles = inner.bundles.read().unwrap();
bundles
.get(&scope.bundle_key)
.cloned()
.or_else(|| bundles.get(PUBLIC_BUNDLE_KEY).cloned())
.unwrap_or_else(|| {
Self::create_empty_bundle(
PUBLIC_BUNDLE_KEY.to_owned(),
None,
inner.opts.refresh_rate_seconds,
)
})
};
if !bundle.ready {
return GuardDecision {
name: name.to_owned(),
is_open: options.fallback.unwrap_or(inner.opts.fallback),
adopted: false,
reason: GuardDecisionReason::Fallback,
matched_rule_index: -1,
properties: Properties::new(),
};
}
let guard = match bundle.guards.get(name) {
Some(guard) => guard.clone(),
None => {
Self::record_unadopted_guard(inner, name.to_owned());
return GuardDecision {
name: name.to_owned(),
is_open: true,
adopted: false,
reason: GuardDecisionReason::Unadopted,
matched_rule_index: -1,
properties: Properties::new(),
};
}
};
if !guard.adopted {
Self::record_unadopted_guard(inner, name.to_owned());
return GuardDecision {
name: name.to_owned(),
is_open: true,
adopted: false,
reason: GuardDecisionReason::Unadopted,
matched_rule_index: -1,
properties: Properties::new(),
};
}
let mut props = scope.properties.clone();
for (key, value) in &options.properties.0 {
props.insert(key.clone(), value.clone());
}
let detailed = evaluate_guard_detailed(&guard, &props);
let applied_rate_limit =
Self::apply_rate_limit(inner, &guard, name, detailed.result, &props, true);
let result = applied_rate_limit.result;
let reason = if detailed.matched_rule_index >= 0 {
GuardDecisionReason::MatchedRule
} else {
GuardDecisionReason::DefaultValue
};
let matched_idx = detailed.matched_rule_index as i32;
Self::buffer_signal_inner(
inner,
BufferedSignalInput {
guard_name: name.to_owned(),
result,
props: props.clone(),
metadata: Self::next_signal_metadata(None),
callsite_id: format!("{}:{}", caller.file(), caller.line()),
kind: "guard_check".into(),
measurement: if Self::is_measurement_enabled(inner, &guard, options) {
Self::capture_guard_check_measurement()
} else {
None
},
rate_limit_decisions: applied_rate_limit.rate_limit_decisions,
},
);
GuardDecision {
name: name.to_owned(),
is_open: result,
adopted: guard.adopted,
reason,
matched_rule_index: matched_idx,
properties: props,
}
}
/// Installs a fresh execution context (new id, sequence 0) on this thread, replacing any
/// active one.
///
/// The closure passed to `Execution::new` puts back the context that was active before
/// this call.
fn start_execution_inner(_inner: &Arc<Inner>) -> Execution {
    let execution_id = next_signal_id();
    let fresh = ExecutionState {
        execution_id: execution_id.clone(),
        sequence_number: 0,
        last_signal_id: None,
    };
    let saved = EXECUTION_STATE.with(|slot| slot.replace(Some(fresh)));
    Execution::new(execution_id, move || {
        EXECUTION_STATE.with(|slot| {
            *slot.borrow_mut() = saved;
        });
    })
}
/// Records one check of an unknown or unadopted guard: sets `first_seen_ms` on the first
/// sighting, and updates `last_seen_ms` and bumps `check_count` on every sighting.
fn record_unadopted_guard(inner: &Arc<Inner>, name: String) {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // Saturate rather than wrap if the millisecond count ever exceeds i64.
    let now_ms = i64::try_from(since_epoch.as_millis()).unwrap_or(i64::MAX);
    let mut pending = inner.pending_unadopted_guards.lock().unwrap();
    let observation = pending
        .entry(name)
        .or_insert(PendingUnadoptedGuardObservation {
            first_seen_ms: now_ms,
            last_seen_ms: now_ms,
            check_count: 0,
        });
    observation.last_seen_ms = now_ms;
    observation.check_count += 1;
}
fn rate_limit_bucket_key(
name: &str,
rate_limit_properties: &[String],
props: &Properties,
) -> String {
if rate_limit_properties.is_empty() {
return name.to_owned();
}
let mut key = String::from(name);
for property_name in rate_limit_properties {
key.push('\0');
key.push_str(property_name);
key.push('=');
if let Some(value) = props.get(property_name) {
key.push_str(&Self::property_value_to_string(value));
}
}
key
}
fn rate_limit_slot_key(
name: &str,
evaluation_target: RateLimitEvaluationTarget,
dry_run_id: Option<&str>,
) -> String {
match evaluation_target {
RateLimitEvaluationTarget::Active => format!("{name}\0active"),
RateLimitEvaluationTarget::DryRun => {
format!("{name}\0dry_run={}", dry_run_id.unwrap_or_default())
}
}
}
fn partition_values(
rate_limit_properties: &[String],
props: &Properties,
) -> HashMap<String, crate::types::PropertyValue> {
let mut values = HashMap::new();
for property_name in rate_limit_properties {
if let Some(value) = props.get(property_name) {
values.insert(property_name.clone(), value.clone());
}
}
values
}
fn build_rate_limit_decision(
evaluation_target: RateLimitEvaluationTarget,
dry_run_id: Option<String>,
outcome: RateLimitDecisionOutcome,
requests_per_minute: i32,
rate_limit_properties: &[String],
props: &Properties,
count_in_window: i32,
) -> RateLimitDecision {
RateLimitDecision {
evaluation_target,
dry_run_id,
outcome,
requests_per_minute,
partition_properties: rate_limit_properties.to_vec(),
partition_values: Self::partition_values(rate_limit_properties, props),
count_in_window,
}
}
fn property_value_to_string(value: &crate::types::PropertyValue) -> String {
match value {
crate::types::PropertyValue::Bool(v) => v.to_string(),
crate::types::PropertyValue::Number(v) => v.to_string(),
crate::types::PropertyValue::Text(v) => v.clone(),
}
}
/// Checks the fixed one-minute window for the bucket identified by `slot_key` and the
/// partition values, optionally consuming one unit of quota.
///
/// `count_in_window` counts the current request (`count + 1`) whether or not it is
/// allowed. Quota is only consumed when `consume` is set and the request is allowed.
fn evaluate_rate_limit(
    inner: &Arc<Inner>,
    slot_key: &str,
    limit_per_minute: i32,
    rate_limit_properties: &[String],
    props: &Properties,
    consume: bool,
) -> RateLimitEvaluation {
    const WINDOW: Duration = Duration::from_secs(60);
    let bucket = Self::rate_limit_bucket_key(slot_key, rate_limit_properties, props);
    let now = Instant::now();
    let mut buckets = inner.rate_limit_state.lock().unwrap();
    let entry = buckets.entry(bucket).or_insert(RateLimitEntry {
        window_start: now,
        count: 0,
    });
    if now.duration_since(entry.window_start) >= WINDOW {
        *entry = RateLimitEntry {
            window_start: now,
            count: 0,
        };
    }
    let used = entry.count;
    let allowed = used < limit_per_minute;
    if allowed && consume {
        entry.count = used + 1;
    }
    RateLimitEvaluation {
        allowed,
        count_in_window: used + 1,
    }
}
/// Checks and consumes quota from the given target's bucket.
fn consume_rate_limit(
    inner: &Arc<Inner>,
    name: &str,
    evaluation_target: RateLimitEvaluationTarget,
    dry_run_id: Option<&str>,
    limit_per_minute: i32,
    rate_limit_properties: &[String],
    props: &Properties,
) -> RateLimitEvaluation {
    let slot_key = Self::rate_limit_slot_key(name, evaluation_target, dry_run_id);
    Self::evaluate_rate_limit(
        inner,
        &slot_key,
        limit_per_minute,
        rate_limit_properties,
        props,
        true,
    )
}
/// Checks the given target's bucket without consuming quota.
fn peek_rate_limit(
    inner: &Arc<Inner>,
    name: &str,
    evaluation_target: RateLimitEvaluationTarget,
    dry_run_id: Option<&str>,
    limit_per_minute: i32,
    rate_limit_properties: &[String],
    props: &Properties,
) -> RateLimitEvaluation {
    let slot_key = Self::rate_limit_slot_key(name, evaluation_target, dry_run_id);
    Self::evaluate_rate_limit(
        inner,
        &slot_key,
        limit_per_minute,
        rate_limit_properties,
        props,
        false,
    )
}
/// Applies a guard's active and dry-run rate limits to an already-open result.
///
/// A closed `initial_result` passes through untouched. The active limit (if positive)
/// can close the guard; it consumes quota only when `emit_signal` is set and is peeked
/// otherwise. The dry-run limit is only evaluated when `emit_signal` is set, and it never
/// changes the result. Decisions are collected only for signal-emitting checks.
fn apply_rate_limit(
    inner: &Arc<Inner>,
    guard: &Guard,
    name: &str,
    initial_result: bool,
    props: &Properties,
    emit_signal: bool,
) -> AppliedRateLimit {
    if !initial_result {
        return AppliedRateLimit {
            result: false,
            rate_limit_decisions: Vec::new(),
        };
    }
    let mut decisions = Vec::new();

    let active_limit = guard
        .rate_limit
        .as_ref()
        .filter(|limit| limit.requests_per_minute > 0);
    let result = match active_limit {
        None => true,
        Some(limit) => {
            let evaluation = if emit_signal {
                Self::consume_rate_limit(
                    inner,
                    name,
                    RateLimitEvaluationTarget::Active,
                    None,
                    limit.requests_per_minute,
                    &limit.partition_properties,
                    props,
                )
            } else {
                Self::peek_rate_limit(
                    inner,
                    name,
                    RateLimitEvaluationTarget::Active,
                    None,
                    limit.requests_per_minute,
                    &limit.partition_properties,
                    props,
                )
            };
            if emit_signal {
                let outcome = if evaluation.allowed {
                    RateLimitDecisionOutcome::WithinLimit
                } else {
                    RateLimitDecisionOutcome::Limited
                };
                decisions.push(Self::build_rate_limit_decision(
                    RateLimitEvaluationTarget::Active,
                    None,
                    outcome,
                    limit.requests_per_minute,
                    &limit.partition_properties,
                    props,
                    evaluation.count_in_window,
                ));
            }
            evaluation.allowed
        }
    };

    let dry_run_limit = guard
        .dry_run_rate_limit
        .as_ref()
        .filter(|limit| emit_signal && limit.requests_per_minute > 0);
    if let Some(limit) = dry_run_limit {
        let evaluation = Self::consume_rate_limit(
            inner,
            name,
            RateLimitEvaluationTarget::DryRun,
            Some(limit.dry_run_id.as_str()),
            limit.requests_per_minute,
            &limit.partition_properties,
            props,
        );
        let outcome = if evaluation.allowed {
            RateLimitDecisionOutcome::WithinLimit
        } else {
            RateLimitDecisionOutcome::WouldLimit
        };
        decisions.push(Self::build_rate_limit_decision(
            RateLimitEvaluationTarget::DryRun,
            Some(limit.dry_run_id.clone()),
            outcome,
            limit.requests_per_minute,
            &limit.partition_properties,
            props,
            evaluation.count_in_window,
        ));
    }

    AppliedRateLimit {
        result,
        rate_limit_decisions: decisions,
    }
}
/// Builds a `Signal` from `input`, appends it to the shared signal buffer and returns it.
///
/// The signal is stamped with:
/// * the current wall-clock time in milliseconds, clamped to `i64::MAX`;
/// * the active OpenTelemetry trace context, if any;
/// * the number of signals dropped since the previous buffered signal. That pending drop
///   counter is reset to zero.
///
/// The buffer is capped at `flush_size * flush_buffer_multiplier` entries, never fewer than
/// one. When it is full, the oldest signals are evicted. Each eviction adds one to the
/// pending drop counter, plus whatever drop count the evicted signal was itself carrying.
/// That way, losses recorded only on an evicted signal are still reported by a later one.
///
/// Once the buffer holds `flush_size` or more signals, a background flush is spawned on
/// the client's runtime; its result is ignored here.
fn buffer_signal_inner(inner: &Arc<Inner>, input: BufferedSignalInput) -> Signal {
    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis()
        .min(i64::MAX as u128) as i64;
    // Claim (and reset) every drop recorded since the last buffered signal.
    let dropped_signals_since_last =
        std::mem::take(&mut *inner.dropped_signals_pending.lock().unwrap());
    let signal = Signal {
        guard_name: input.guard_name,
        result: input.result,
        properties: input.props.0,
        timestamp_ms: ts,
        trace: Client::capture_trace_context(),
        signal_id: input.metadata.signal_id,
        execution_id: input.metadata.execution_id,
        parent_signal_id: input.metadata.parent_signal_id,
        sequence_number: input.metadata.sequence_number,
        callsite_id: input.callsite_id,
        kind: input.kind,
        dropped_signals_since_last,
        measurement: input.measurement,
        rate_limit_decisions: input.rate_limit_decisions,
    };
    let flush_size = inner.opts.flush_size;
    // A zero cap would make the eviction below index an empty buffer; always keep room
    // for the signal being added. Saturate instead of overflowing on absurd settings.
    let max_buffered_signals = flush_size
        .saturating_mul(inner.opts.flush_buffer_multiplier)
        .max(1);
    let mut buffer = inner.signal_buffer.lock().unwrap();
    // How many of the oldest entries must go so that, after the push, len <= cap.
    let overflow = (buffer.len() + 1).saturating_sub(max_buffered_signals);
    if overflow > 0 {
        // Each evicted signal is one drop, plus the drops it was carrying. Otherwise those
        // earlier losses would vanish with it and never reach the backend.
        let lost = buffer
            .drain(..overflow)
            .fold(0, |total, evicted| total + 1 + evicted.dropped_signals_since_last);
        *inner.dropped_signals_pending.lock().unwrap() += lost;
    }
    buffer.push(signal.clone());
    if buffer.len() >= flush_size {
        // Release the buffer lock before spawning: the flush task takes it again.
        drop(buffer);
        let inner = Arc::clone(inner);
        let runtime = inner.runtime.clone();
        runtime.spawn(async move {
            let _ = Client::flush_signals_inner(&inner).await;
        });
    }
    signal
}
/// Marks the client as shut down and aborts the refresh and flush tasks.
///
/// Idempotent: `swap` on `is_shutdown` means only the first caller aborts the tasks.
fn stop_background_tasks(&self) {
    let already_stopped = self.is_shutdown.swap(true, Ordering::SeqCst);
    if !already_stopped {
        self.refresh_task.abort();
        self.flush_task.abort();
    }
}
/// Sends everything queued for the backend: buffered signals and unadopted-guard
/// observations.
///
/// Both queues are emptied up front. Each non-empty queue is sent in its own request,
/// signals first. A failure sending one queue does not stop the other from being attempted.
///
/// # Errors
/// Returns the first error encountered. If both requests fail, the signal error is returned.
async fn flush_signals_inner(inner: &Arc<Inner>) -> Result<(), reqwest::Error> {
    let signals = {
        let mut guard = inner.signal_buffer.lock().unwrap();
        std::mem::take(&mut *guard)
    };
    let observations = Self::take_pending_unadopted_guards(inner);
    let signal_outcome = if signals.is_empty() {
        Ok(())
    } else {
        Self::flush_signal_batch(inner, signals).await
    };
    let observation_outcome = if observations.is_empty() {
        Ok(())
    } else {
        Self::flush_unadopted_guard_batch(inner, observations).await
    };
    // `Result::and` keeps the first `Err`, which gives "first error wins".
    signal_outcome.and(observation_outcome)
}
/// POSTs one batch of signals to `{backend_url}/api/v1/signals`.
///
/// The project client token is used both in the body and as the bearer token. The
/// environment (empty when unset) goes in the body; the `X-Liteguard-Environment` header is
/// added only when it is non-empty.
///
/// On a transport error or non-2xx status, the batch is put back at the front of the signal
/// buffer, ahead of anything buffered while the request was in flight. The combined buffer
/// is then trimmed from the oldest end to `flush_size * flush_buffer_multiplier`. Every
/// trimmed signal counts as one drop, plus the drop count it was carrying itself, so those
/// earlier losses are still reported.
///
/// # Errors
/// Returns the `reqwest` error from sending the request or from `error_for_status`.
async fn flush_signal_batch(
    inner: &Arc<Inner>,
    batch: Vec<Signal>,
) -> Result<(), reqwest::Error> {
    #[derive(serde::Serialize)]
    #[serde(rename_all = "camelCase")]
    struct Payload<'a> {
        project_client_token: &'a str,
        environment: String,
        signals: &'a [Signal],
    }
    let payload = Payload {
        project_client_token: &inner.project_client_token,
        environment: inner.opts.environment.clone().unwrap_or_default(),
        signals: &batch,
    };
    let environment = payload.environment.clone();
    let url = format!(
        "{}/api/v1/signals",
        inner.opts.backend_url.trim_end_matches('/')
    );
    let mut request = inner
        .http
        .post(&url)
        .bearer_auth(&inner.project_client_token)
        .json(&payload);
    if !environment.is_empty() {
        request = request.header("X-Liteguard-Environment", &environment);
    }
    let result = request
        .send()
        .await
        .and_then(|response| response.error_for_status());
    if let Err(err) = result {
        inner.log(format!(
            "signal flush failed: {err} for {} signals",
            batch.len()
        ));
        let mut buffer = inner.signal_buffer.lock().unwrap();
        // Failed batch first, then whatever arrived meanwhile, preserving chronological order.
        let newer = std::mem::replace(&mut *buffer, batch);
        buffer.extend(newer);
        let max_buffered_signals = inner
            .opts
            .flush_size
            .saturating_mul(inner.opts.flush_buffer_multiplier);
        if buffer.len() > max_buffered_signals {
            let overflow = buffer.len() - max_buffered_signals;
            // Count each discarded signal plus the drops it was carrying. This also avoids
            // the truncating `usize as i32` cast.
            let lost = buffer
                .drain(..overflow)
                .fold(0, |total, signal| total + 1 + signal.dropped_signals_since_last);
            *inner.dropped_signals_pending.lock().unwrap() += lost;
        }
        return Err(err);
    }
    Ok(())
}
/// POSTs the given unadopted-guard observations to `{backend_url}/api/v1/unadopted-guards`.
///
/// The environment (empty when unset) goes in the body. The `X-Liteguard-Environment`
/// header is added only when it is non-empty. On failure, the observations are merged back
/// into the pending set via `requeue_pending_unadopted_guards` so a later flush retries
/// them.
///
/// # Errors
/// Returns the `reqwest` error from sending the request or from `error_for_status`.
async fn flush_unadopted_guard_batch(
    inner: &Arc<Inner>,
    observations: Vec<UnadoptedGuardObservation>,
) -> Result<(), reqwest::Error> {
    let environment = inner.opts.environment.clone().unwrap_or_default();
    let payload = SendUnadoptedGuardsRequest {
        project_client_token: inner.project_client_token.clone(),
        environment: environment.clone(),
        observations: observations.clone(),
    };
    let endpoint = format!(
        "{}/api/v1/unadopted-guards",
        inner.opts.backend_url.trim_end_matches('/')
    );
    let request = inner
        .http
        .post(&endpoint)
        .bearer_auth(&inner.project_client_token)
        .json(&payload);
    let request = if environment.is_empty() {
        request
    } else {
        request.header("X-Liteguard-Environment", &environment)
    };
    match request
        .send()
        .await
        .and_then(|response| response.error_for_status())
    {
        Ok(_) => Ok(()),
        Err(err) => {
            inner.log(format!(
                "unadopted guard flush failed: {err} for {} guards",
                observations.len()
            ));
            Self::requeue_pending_unadopted_guards(inner, observations);
            Err(err)
        }
    }
}
/// Atomically empties the pending unadopted-guard map and returns its contents as wire
/// observations sorted by guard name.
///
/// Each observation gets an `estimated_checks_per_minute` derived from its first/last
/// seen timestamps and check count via `estimate_checks_per_minute`.
fn take_pending_unadopted_guards(inner: &Arc<Inner>) -> Vec<UnadoptedGuardObservation> {
    let drained = {
        let mut pending = inner.pending_unadopted_guards.lock().unwrap();
        std::mem::take(&mut *pending)
    };
    let mut observations: Vec<UnadoptedGuardObservation> = drained
        .into_iter()
        .map(|(guard_name, observation)| UnadoptedGuardObservation {
            guard_name,
            first_seen_ms: observation.first_seen_ms,
            last_seen_ms: observation.last_seen_ms,
            check_count: observation.check_count,
            estimated_checks_per_minute: estimate_checks_per_minute(
                observation.first_seen_ms,
                observation.last_seen_ms,
                observation.check_count,
            ),
        })
        .collect();
    // Map keys are unique, so an unstable sort yields the same order as a stable one.
    observations.sort_unstable_by(|left, right| left.guard_name.cmp(&right.guard_name));
    observations
}
fn requeue_pending_unadopted_guards(
inner: &Arc<Inner>,
observations: Vec<UnadoptedGuardObservation>,
) {
if observations.is_empty() {
return;
}
let mut pending = inner.pending_unadopted_guards.lock().unwrap();
for observation in observations {
pending
.entry(observation.guard_name.clone())
.and_modify(|existing| {
existing.first_seen_ms = existing.first_seen_ms.min(observation.first_seen_ms);
existing.last_seen_ms = existing.last_seen_ms.max(observation.last_seen_ms);
existing.check_count += observation.check_count;
})
.or_insert(PendingUnadoptedGuardObservation {
first_seen_ms: observation.first_seen_ms,
last_seen_ms: observation.last_seen_ms,
check_count: observation.check_count,
});
}
}
/// Allocates a fresh signal id and the execution metadata that goes with it.
///
/// Inside an execution (the thread-local `EXECUTION_STATE` is `Some`):
/// * the sequence number is advanced;
/// * the parent is `parent_signal_id_override` if given, otherwise the execution's previous
///   signal;
/// * the new id becomes the execution's last signal.
///
/// Outside an execution, the signal starts its own one-off execution with a freshly
/// generated id, sequence number 1 and no parent.
///
/// NOTE(review): `parent_signal_id_override` is ignored outside an execution; confirm that
/// no caller relies on it there.
fn next_signal_metadata(parent_signal_id_override: Option<String>) -> SignalMetadata {
    let signal_id = next_signal_id();
    EXECUTION_STATE.with(|slot| {
        let mut current = slot.borrow_mut();
        let Some(state) = current.as_mut() else {
            return SignalMetadata {
                signal_id,
                execution_id: next_signal_id(),
                parent_signal_id: None,
                sequence_number: 1,
            };
        };
        state.sequence_number += 1;
        // Resolve the parent before this signal replaces `last_signal_id`.
        let parent_signal_id =
            parent_signal_id_override.or_else(|| state.last_signal_id.clone());
        state.last_signal_id = Some(signal_id.clone());
        SignalMetadata {
            signal_id,
            execution_id: state.execution_id.clone(),
            parent_signal_id,
            sequence_number: state.sequence_number,
        }
    })
}
/// Returns whether measurements should be attached for this check.
///
/// Measurement is off if any level disables it: the client options, the per-call
/// `Options`, or the guard's own `disable_measurement` flag. An unset guard flag counts as
/// enabled.
fn is_measurement_enabled(inner: &Arc<Inner>, guard: &Guard, options: &Options) -> bool {
    let disabled_by_client = inner.opts.disable_measurement;
    let disabled_by_call = options.disable_measurement;
    let disabled_by_guard = guard.disable_measurement == Some(true);
    !(disabled_by_client || disabled_by_call || disabled_by_guard)
}
/// Returns the measurement attached to guard-check signals.
///
/// No check-level metrics are collected, so this is always `None`.
fn capture_guard_check_measurement() -> Option<SignalPerformance> {
    Option::None
}
/// Builds the execution measurement for a guarded block that started at `started_at`.
///
/// Only the wall-clock duration (nanoseconds, saturating at `i64::MAX`), the completion
/// flag and the optional error class are filled in. All resource counters are left `None`,
/// as is the guard-check part.
fn capture_guard_execution_measurement(
    started_at: Instant,
    completed: bool,
    error_class: Option<String>,
) -> SignalPerformance {
    let elapsed_ns = started_at.elapsed().as_nanos();
    let duration_ns = i64::try_from(elapsed_ns).unwrap_or(i64::MAX);
    let execution = GuardExecutionPerformance {
        duration_ns,
        rss_end_bytes: None,
        heap_used_end_bytes: None,
        heap_total_end_bytes: None,
        cpu_time_end_ns: None,
        gc_count_end: None,
        thread_count_end: None,
        completed,
        error_class,
    };
    SignalPerformance {
        guard_check: None,
        guard_execution: Some(execution),
    }
}
/// Classifies a caught panic payload for execution telemetry.
///
/// The payload is not inspected: every panic is reported with the class `"panic"`.
fn panic_error_class(_payload: &(dyn std::any::Any + Send)) -> String {
    String::from("panic")
}
/// Captures the trace and span ids of the span attached to the current OpenTelemetry
/// context.
///
/// Returns `None` when there is no valid span. `parent_span_id` is always left empty.
fn capture_trace_context() -> Option<TraceContext> {
    let current = opentelemetry::Context::current();
    let active_span = current.span();
    let span_context = active_span.span_context();
    span_context.is_valid().then(|| TraceContext {
        trace_id: span_context.trace_id().to_string(),
        span_id: span_context.span_id().to_string(),
        parent_span_id: String::new(),
    })
}
}
/// Generates a process-unique id of the form `{unix_nanos:x}-{counter:x}`.
///
/// The counter is a process-wide atomic, so ids never repeat within a process even when
/// the clock does not advance. A clock before the Unix epoch yields a `0` timestamp part.
fn next_signal_id() -> String {
    let sequence = SIGNAL_COUNTER.fetch_add(1, Ordering::Relaxed) + 1;
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);
    format!("{nanos:x}-{sequence:x}")
}
impl Drop for Client {
    /// Stops the background tasks and warns about signals that never reached the backend.
    ///
    /// Dropping does not flush. It only logs how many signals are still buffered, reading
    /// the buffer even if its mutex was poisoned.
    fn drop(&mut self) {
        self.stop_background_tasks();
        let pending = self
            .inner
            .signal_buffer
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();
        if pending == 0 {
            return;
        }
        self.inner.log(format!(
            "client dropped with {pending} unflushed signals, call shutdown().await before dropping to avoid signal loss"
        ));
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{
GuardDryRunRateLimit, GuardRateLimitConfig, Operator, PropertyValue,
RateLimitDecisionOutcome, RateLimitEvaluationTarget, Rule,
};
use opentelemetry::trace::{Span, TraceContextExt, Tracer, TracerProvider};
use opentelemetry_sdk::trace::SdkTracerProvider;
use std::io::{Read, Write};
use std::net::TcpListener;
use std::thread;
use tokio::runtime::{Builder, Runtime};
/// Builds a multi-threaded runtime and a `Client` initialised against
/// `http://localhost:9999`. Presumably nothing listens there, so no guards are loaded.
///
/// The runtime is returned alongside the client so the caller keeps it alive while using
/// the client.
fn make_client() -> (Runtime, Client) {
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let options = ClientOptions {
        backend_url: "http://localhost:9999".into(),
        ..Default::default()
    };
    let client = runtime
        .block_on(Client::init("test-key", options))
        .unwrap();
    (runtime, client)
}
/// Builds an adopted `Guard` with the given rules and optional active rate limit.
///
/// The guard has no dry-run limit and no measurement override.
fn adopted_guard(
    name: &str,
    default_value: bool,
    rules: Vec<Rule>,
    rate_limit: Option<GuardRateLimitConfig>,
) -> Guard {
    Guard {
        adopted: true,
        default_value,
        disable_measurement: None,
        dry_run_rate_limit: None,
        name: name.into(),
        rate_limit,
        rules,
    }
}
/// Builds an active rate-limit config of `requests_per_minute`, partitioned by the named
/// properties.
fn rate_limit(
    requests_per_minute: i32,
    partition_properties: Vec<&str>,
) -> GuardRateLimitConfig {
    let partition_properties = partition_properties
        .into_iter()
        .map(String::from)
        .collect();
    GuardRateLimitConfig {
        requests_per_minute,
        partition_properties,
    }
}
/// Builds a dry-run rate-limit config identified by `dry_run_id`, with
/// `requests_per_minute` partitioned by the named properties.
fn dry_run_rate_limit(
    dry_run_id: &str,
    requests_per_minute: i32,
    partition_properties: Vec<&str>,
) -> GuardDryRunRateLimit {
    let partition_properties = partition_properties
        .into_iter()
        .map(String::from)
        .collect();
    GuardDryRunRateLimit {
        dry_run_id: dry_run_id.into(),
        requests_per_minute,
        partition_properties,
    }
}
/// Builds an enabled `Rule` on property `prop` that compares with `operator` against
/// `values` and yields `result`.
fn rule(prop: &str, operator: Operator, values: Vec<PropertyValue>, result: bool) -> Rule {
    let property_name = prop.into();
    Rule {
        enabled: true,
        operator,
        property_name,
        result,
        values,
    }
}
/// Builds per-call `Options` carrying `properties`, with no fallback and measurement left
/// enabled.
fn with_props(properties: Properties) -> Options {
    Options {
        disable_measurement: false,
        fallback: None,
        properties,
    }
}
/// Reads one HTTP/1.1 request (headers plus `Content-Length` body) from `stream` and
/// returns it as lossy UTF-8.
///
/// Keeps reading until the blank line that ends the headers has arrived. It then reads
/// exactly as many body bytes as the `Content-Length` header announces; the header name is
/// matched case-insensitively. Panics on I/O errors or if the peer closes before the
/// headers end, which is what a test stub server wants.
fn read_request(stream: &mut std::net::TcpStream) -> String {
    let mut request = Vec::new();
    let mut chunk = [0_u8; 4096];
    // The headers may arrive across several reads, so accumulate until the terminator
    // shows up instead of assuming one read is enough.
    let header_end = loop {
        if let Some(position) = request.windows(4).position(|w| w == b"\r\n\r\n") {
            break position + 4;
        }
        let size = stream.read(&mut chunk).unwrap();
        assert!(size > 0, "connection closed before request headers were complete");
        request.extend_from_slice(&chunk[..size]);
    };
    let content_length = {
        let header_text = String::from_utf8_lossy(&request[..header_end]);
        // reqwest/hyper send lowercase header names by default, so an exact
        // "Content-Length: " prefix match would never find the header.
        header_text
            .lines()
            .find_map(|line| {
                let (name, value) = line.split_once(':')?;
                if name.trim().eq_ignore_ascii_case("content-length") {
                    value.trim().parse::<usize>().ok()
                } else {
                    None
                }
            })
            .unwrap_or(0)
    };
    let already_read = request.len() - header_end;
    if content_length > already_read {
        let mut body = vec![0_u8; content_length - already_read];
        stream.read_exact(&mut body).unwrap();
        request.extend_from_slice(&body);
    }
    String::from_utf8_lossy(&request).into_owned()
}
/// Writes a minimal `Connection: close` HTTP/1.1 JSON response with the given status and
/// body. The reason phrase is always `TEST`.
fn write_response(stream: &mut std::net::TcpStream, status: u16, body: &str) {
    let length = body.len();
    let head = format!(
        "HTTP/1.1 {status} TEST\r\nContent-Length: {length}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n"
    );
    let full = format!("{head}{body}");
    stream.write_all(full.as_bytes()).unwrap();
}
/// Zeroed numeric knobs and an empty backend URL are all invalid and must be replaced by
/// the defaults.
#[test]
fn options_with_defaults_normalize_invalid_values() {
    let raw = ClientOptions {
        backend_url: String::new(),
        disable_measurement: false,
        environment: None,
        fallback: false,
        flush_buffer_multiplier: 0,
        flush_rate_seconds: 0,
        flush_size: 0,
        http_timeout_seconds: 0,
        quiet: true,
        refresh_rate_seconds: 0,
    };
    let normalized = raw.with_defaults();
    assert_eq!(normalized.refresh_rate_seconds, 60);
    assert_eq!(normalized.flush_rate_seconds, 60);
    assert_eq!(normalized.flush_size, 500);
    assert_eq!(normalized.backend_url, "https://api.liteguard.io");
    assert_eq!(normalized.http_timeout_seconds, 4);
    assert_eq!(normalized.flush_buffer_multiplier, 4);
}
/// Scope properties feed rule evaluation, and per-call properties take precedence over
/// the scope's values for a single check.
#[test]
fn evaluates_rules_using_explicit_scopes_and_per_call_properties() {
    let (_runtime, client) = make_client();
    let plan_is_pro = rule(
        "plan",
        Operator::Equals,
        vec![PropertyValue::Text("pro".into())],
        true,
    );
    let user_is_canary = rule(
        "userId",
        Operator::In,
        vec![PropertyValue::Text("canary-1".into())],
        true,
    );
    client.set_guards_for_testing(vec![
        adopted_guard("feature.plan", false, vec![plan_is_pro], None),
        adopted_guard("feature.canary", false, vec![user_is_canary], None),
    ]);
    let scope_properties = Properties::new()
        .set("plan", "free")
        .set("userId", "normal-user");
    let scope = client.create_scope(scope_properties);
    // plan=free does not match the pro-only rule, so the closed default applies.
    assert!(!scope.is_open("feature.plan"));
    // The per-call userId replaces the scope's "normal-user" for this check.
    let per_call = with_props(Properties::new().set("userId", "canary-1"));
    assert!(scope.is_open_with_options("feature.canary", &per_call));
}
/// Two threads check `feature.plan` through scopes derived with different `plan` values.
/// Both threads block on the `ready`/`release` barriers (shared with the main thread) so
/// their checks happen together. One must see the guard open and the other closed.
#[test]
fn scopes_isolate_request_local_properties_across_threads() {
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![adopted_guard(
"feature.plan",
false,
vec![rule(
"plan",
Operator::Equals,
vec![PropertyValue::Text("pro".into())],
true,
)],
None,
)]);
let free_scope = client.create_scope(Properties::new().set("requestId", "free"));
let pro_scope = client.create_scope(Properties::new().set("requestId", "pro"));
// Three parties: the two worker threads plus this thread.
let ready = Arc::new(std::sync::Barrier::new(3));
let release = Arc::new(std::sync::Barrier::new(3));
let free_ready = ready.clone();
let free_release = release.clone();
let free = thread::spawn(move || {
let scope = free_scope.add_properties(Properties::new().set("plan", "free"));
free_ready.wait();
free_release.wait();
scope.is_open("feature.plan")
});
let pro_ready = ready.clone();
let pro_release = release.clone();
let pro = thread::spawn(move || {
let scope = pro_scope.add_properties(Properties::new().set("plan", "pro"));
pro_ready.wait();
pro_release.wait();
scope.is_open("feature.plan")
});
ready.wait();
release.wait();
let results = [free.join().unwrap(), pro.join().unwrap()];
assert!(results.contains(&false));
assert!(results.contains(&true));
}
/// Binding a protected context derives a scope backed by its own guard bundle.
///
/// A stub server answers exactly three `/api/v1/guards` requests:
/// * the public bundle fetched by `init`, where `private.dashboard` is closed;
/// * one bundle for each of the two distinct protected contexts, where it is open.
///
/// Binding `protected_context_a` a second time must be served from the bundle cache.
/// Clearing the context must fall back to the public bundle without a fourth request.
#[test]
fn protected_context_binds_to_derived_scopes_and_caches_bundles() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let address = listener.local_addr().unwrap();
let request_bodies = Arc::new(Mutex::new(Vec::<String>::new()));
let request_bodies_for_server = Arc::clone(&request_bodies);
let server = thread::spawn(move || {
for _ in 0..3 {
let (mut stream, _) = listener.accept().unwrap();
let request = read_request(&mut stream);
request_bodies_for_server
.lock()
.unwrap()
.push(request.clone());
// Requests carrying a protected context get the bundle where the guard is open.
let body = if request.contains("\"protectedContext\"") {
"{\"guards\":[{\"name\":\"private.dashboard\",\"rules\":[],\"defaultValue\":true,\"adopted\":true}],\"refreshRateSeconds\":30,\"etag\":\"protected\"}"
} else {
"{\"guards\":[{\"name\":\"private.dashboard\",\"rules\":[],\"defaultValue\":false,\"adopted\":true}],\"refreshRateSeconds\":30,\"etag\":\"public\"}"
};
write_response(&mut stream, 200, body);
}
});
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: format!("http://{address}"),
refresh_rate_seconds: 3600,
flush_rate_seconds: 3600,
..Default::default()
},
))
.unwrap();
let public_scope = client.create_scope(Properties::new());
let protected_context_a = ProtectedContext {
properties: HashMap::from([(String::from("email"), String::from("alice@acme.com"))]),
signature: "sig-123".into(),
issued_at: None,
expires_at: None,
};
let protected_context_b = ProtectedContext {
properties: HashMap::from([(String::from("email"), String::from("bob@acme.com"))]),
signature: "sig-456".into(),
issued_at: None,
expires_at: None,
};
let protected_scope = runtime
.block_on(public_scope.bind_protected_context(protected_context_a.clone()))
.unwrap();
// Same context as `protected_scope`: expected to reuse the cached bundle.
let cached_scope = runtime
.block_on(
client
.create_scope(Properties::new())
.bind_protected_context(protected_context_a),
)
.unwrap();
let other_scope = runtime
.block_on(
client
.create_scope(Properties::new())
.bind_protected_context(protected_context_b),
)
.unwrap();
assert!(!public_scope.is_open("private.dashboard"));
assert!(protected_scope.is_open("private.dashboard"));
assert!(cached_scope.is_open("private.dashboard"));
assert!(other_scope.is_open("private.dashboard"));
assert!(!runtime
.block_on(protected_scope.clear_protected_context())
.unwrap()
.is_open("private.dashboard"));
// Public bundle + alice's + bob's.
assert_eq!(client.known_bundle_count_for_testing(), 3);
drop(client);
server.join().unwrap();
let requests = request_bodies.lock().unwrap();
assert_eq!(requests.len(), 3);
assert!(requests[0].starts_with("POST /api/v1/guards"));
assert!(requests[1].contains("alice@acme.com"));
assert!(requests[2].contains("bob@acme.com"));
}
/// The effective refresh interval is the shortest `refreshRateSeconds` across all cached
/// bundles, and it may exceed the client's configured 30s.
///
/// The public bundle reports 90s, so the rate starts at 90. Binding a protected context
/// whose bundle reports 45s lowers it to 45.
#[test]
fn refreshes_all_cached_bundles_using_shortest_interval() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let address = listener.local_addr().unwrap();
let server = thread::spawn(move || {
for body in [
"{\"guards\":[],\"refreshRateSeconds\":90,\"etag\":\"public\"}",
"{\"guards\":[],\"refreshRateSeconds\":45,\"etag\":\"protected\"}",
] {
let (mut stream, _) = listener.accept().unwrap();
let _ = read_request(&mut stream);
write_response(&mut stream, 200, body);
}
});
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: format!("http://{address}"),
refresh_rate_seconds: 30,
flush_rate_seconds: 3600,
..Default::default()
},
))
.unwrap();
assert_eq!(client.current_refresh_rate_for_testing(), 90);
let _ = runtime
.block_on(
client
.create_scope(Properties::new())
.bind_protected_context(ProtectedContext {
properties: HashMap::from([(
String::from("email"),
String::from("alice@acme.com"),
)]),
signature: "sig-123".into(),
issued_at: None,
expires_at: None,
}),
)
.unwrap();
assert_eq!(client.current_refresh_rate_for_testing(), 45);
drop(client);
server.join().unwrap();
}
/// The server reports `refreshRateSeconds` 5, below the client's configured 30s, yet the
/// effective rate stays at 30. The configured value acts as a floor.
///
/// NOTE(review): the test name says the server *can* speed refreshes up below the client
/// default, but the assertion checks the opposite (the rate stays at 30). Consider renaming
/// it, e.g. `server_cannot_speed_up_refreshes_below_client_default`.
#[test]
fn server_can_speed_up_refreshes_below_client_default() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let address = listener.local_addr().unwrap();
let server = thread::spawn(move || {
let (mut stream, _) = listener.accept().unwrap();
let _ = read_request(&mut stream);
write_response(
&mut stream,
200,
"{\"guards\":[],\"refreshRateSeconds\":5,\"etag\":\"public\"}",
);
});
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: format!("http://{address}"),
refresh_rate_seconds: 30,
flush_rate_seconds: 3600,
..Default::default()
},
))
.unwrap();
assert_eq!(client.current_refresh_rate_for_testing(), 30);
drop(client);
server.join().unwrap();
}
/// Before any guard bundle has loaded, unknown guards evaluate to the client's configured
/// `fallback`. Presumably nothing listens on `localhost:9999`, so the initial load cannot
/// succeed.
#[test]
fn fallback_is_used_before_initial_load() {
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let init_with_fallback = |fallback: bool| {
        runtime
            .block_on(Client::init(
                "test-key",
                ClientOptions {
                    backend_url: "http://localhost:9999".into(),
                    fallback,
                    ..Default::default()
                },
            ))
            .unwrap()
    };
    let fallback_true = init_with_fallback(true);
    assert!(fallback_true.is_open("unknown"));
    let fallback_false = init_with_fallback(false);
    assert!(!fallback_false.is_open("unknown"));
}
/// Repeated checks of an unadopted guard are aggregated into one pending observation,
/// which `flush()` sends to `/api/v1/unadopted-guards`.
///
/// The stub server answers the initial guard fetch with 200. It then asserts that the flush
/// request carries guard `g` with `checkCount` 2, and answers it with 204.
#[test]
fn unadopted_guard_observations_are_aggregated() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let address = listener.local_addr().unwrap();
let server = thread::spawn(move || {
for status in [200, 204] {
let (mut stream, _) = listener.accept().unwrap();
let request = read_request(&mut stream);
let body = if status == 200 {
"{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag\"}"
} else {
assert!(request.starts_with("POST /api/v1/unadopted-guards"));
assert!(request.contains("\"guardName\":\"g\""));
assert!(request.contains("\"checkCount\":2"));
""
};
write_response(&mut stream, status, body);
}
});
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: format!("http://{address}"),
..Default::default()
},
))
.unwrap();
client.set_guards_for_testing(vec![Guard {
name: "g".into(),
rules: vec![],
default_value: false,
adopted: false,
rate_limit: None,
dry_run_rate_limit: None,
disable_measurement: None,
}]);
// NOTE(review): `g` is unadopted with `default_value: false`, yet both checks are expected
// to be open; presumably unadopted guards evaluate as open. Confirm.
assert!(client.is_open("g"));
assert!(client.is_open("g"));
let pending = client.pending_unadopted_guards_for_testing();
assert_eq!(pending.len(), 1);
assert_eq!(pending[0].guard_name, "g");
assert_eq!(pending[0].check_count, 2);
runtime.block_on(client.flush()).unwrap();
assert!(client.pending_unadopted_guards_for_testing().is_empty());
drop(client);
server.join().unwrap();
}
/// A plain check outside any execution produces a `guard_check` signal with populated ids,
/// sequence number 1 and no measurement.
#[test]
fn signal_includes_tracing_metadata() {
    let (_runtime, client) = make_client();
    client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
    assert!(client.is_open("g"));
    let signal = client.inner.signal_buffer.lock().unwrap()[0].clone();
    for (label, value) in [
        ("signal_id", &signal.signal_id),
        ("execution_id", &signal.execution_id),
        ("callsite_id", &signal.callsite_id),
    ] {
        assert!(!value.is_empty(), "{label} should be populated");
    }
    assert_eq!(signal.sequence_number, 1);
    assert_eq!(signal.kind, "guard_check");
    assert!(signal.measurement.is_none());
}
/// Checks made inside `with_execution` share an execution id, and each signal's parent is
/// the previous signal of that execution.
#[test]
fn with_execution_links_signals() {
    let (_runtime, client) = make_client();
    let guards = ["a", "b"].map(|name| adopted_guard(name, true, vec![], None));
    client.set_guards_for_testing(Vec::from(guards));
    client.with_execution(|| {
        for name in ["a", "b"] {
            assert!(client.is_open(name));
        }
    });
    let buffer = client.inner.signal_buffer.lock().unwrap();
    let [first, second] = buffer.as_slice() else {
        panic!("expected exactly two signals, got {}", buffer.len());
    };
    assert_eq!(first.execution_id, second.execution_id);
    assert_eq!(
        second.parent_signal_id.as_deref(),
        Some(first.signal_id.as_str())
    );
}
/// `execute_if_open` on an open guard returns the closure's value and records a
/// `guard_execution` signal. That signal is parented to the check signal and carries a
/// completed execution measurement with no error class.
#[test]
fn execute_if_open_emits_execution_telemetry() {
    let (_runtime, client) = make_client();
    client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
    assert_eq!(client.execute_if_open("g", || 42), Some(42));
    let buffer = client.inner.signal_buffer.lock().unwrap();
    let [check, execution_signal] = buffer.as_slice() else {
        panic!(
            "expected a check signal and an execution signal, got {}",
            buffer.len()
        );
    };
    assert_eq!(execution_signal.kind, "guard_execution");
    assert_eq!(
        execution_signal.parent_signal_id.as_deref(),
        Some(check.signal_id.as_str())
    );
    let execution = execution_signal
        .measurement
        .as_ref()
        .and_then(|measurement| measurement.guard_execution.as_ref())
        .expect("execution measurement should be recorded");
    assert!(execution.completed);
    assert!(execution.error_class.is_none());
}
/// A closure returning `Err` is recorded as a non-completed execution. The error's type
/// name is used as its class.
#[test]
fn try_execute_if_open_records_failed_execution_results() {
    let (_runtime, client) = make_client();
    client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
    let failing = || -> Result<i32, std::io::Error> { Err(std::io::Error::other("boom")) };
    assert!(client.try_execute_if_open("g", failing).is_err());
    let buffer = client.inner.signal_buffer.lock().unwrap();
    assert_eq!(buffer.len(), 2);
    let execution = buffer[1]
        .measurement
        .as_ref()
        .and_then(|measurement| measurement.guard_execution.as_ref())
        .expect("execution measurement should be recorded");
    assert!(!execution.completed);
    assert_eq!(
        execution.error_class.as_deref(),
        Some("std::io::error::Error")
    );
}
/// A panic inside `execute_if_open` must not leave the thread-local execution context
/// behind. The next plain `is_open` starts a new execution: a different execution id, no
/// parent and sequence number 1.
#[test]
fn execution_context_clears_after_panics() {
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
client.execute_if_open("g", || -> i32 { panic!("boom") });
}));
assert!(panic.is_err());
assert!(client.is_open("g"));
let buffer = client.inner.signal_buffer.lock().unwrap();
// Presumably: check + execution signals from the panicking call, then the new check.
assert_eq!(buffer.len(), 3);
assert_ne!(buffer[2].execution_id, buffer[0].execution_id);
assert!(buffer[2].parent_signal_id.is_none());
assert_eq!(buffer[2].sequence_number, 1);
}
/// A signal recorded while an OpenTelemetry span is attached to the current context
/// carries that span's trace and span ids. `parent_span_id` is left empty.
#[test]
fn signals_capture_active_trace_context() {
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
let provider = SdkTracerProvider::builder().build();
let tracer = provider.tracer("liteguard-test");
let span = tracer.start("guard-check");
let expected_trace_id = span.span_context().trace_id().to_string();
let expected_span_id = span.span_context().span_id().to_string();
let context = opentelemetry::Context::current_with_span(span);
// The span is current only while `guard` is alive.
let guard = context.attach();
assert!(client.is_open("g"));
drop(guard);
let _ = provider.shutdown();
let buffer = client.inner.signal_buffer.lock().unwrap();
let trace = buffer[0]
.trace
.as_ref()
.expect("trace context should be captured");
assert_eq!(trace.trace_id, expected_trace_id);
assert_eq!(trace.span_id, expected_span_id);
assert!(trace.parent_span_id.is_empty());
}
/// No measurement is attached when it is disabled at any of three levels: the client
/// options, the per-call `Options`, or the guard's own `disable_measurement`.
///
/// Each scenario uses its own client. Later `let` bindings shadow earlier ones, which stay
/// alive until the end of the test.
#[test]
fn measurement_is_omitted_when_disabled_by_client_check_or_guard() {
// 1. Disabled in the client options.
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: "http://localhost:9999".into(),
disable_measurement: true,
..Default::default()
},
))
.unwrap();
client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
assert!(client.is_open("g"));
assert!(client.inner.signal_buffer.lock().unwrap()[0]
.measurement
.is_none());
// 2. Disabled for a single check via per-call options.
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
assert!(client.is_open_with_options(
"g",
&Options {
disable_measurement: true,
..Default::default()
},
));
assert!(client.inner.signal_buffer.lock().unwrap()[0]
.measurement
.is_none());
// 3. Disabled on the guard; covers both the check and the execution signal.
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![Guard {
name: "g".into(),
rules: vec![],
default_value: true,
adopted: true,
rate_limit: None,
dry_run_rate_limit: None,
disable_measurement: Some(true),
}]);
assert_eq!(client.execute_if_open("g", || 1), Some(1));
let buffer = client.inner.signal_buffer.lock().unwrap();
assert!(buffer[0].measurement.is_none());
assert!(buffer[1].measurement.is_none());
}
/// With `flush_size` 2 and `flush_buffer_multiplier` 2, the buffer is capped at four
/// signals. Once full, each new signal evicts the oldest, and the newest signal reports at
/// least one dropped signal.
///
/// A current-thread runtime is used and not driven after `init`. Presumably this keeps the
/// flush tasks spawned at `flush_size` from running, so the buffer stays observable.
#[test]
fn drops_oldest_signals_when_buffer_is_full() {
let runtime = Builder::new_current_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: "http://localhost:9999".into(),
flush_size: 2,
flush_buffer_multiplier: 2,
..Default::default()
},
))
.unwrap();
client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
for _ in 0..5 {
assert!(client.is_open("g"));
}
let oldest_signal_id = {
let buffer = client.inner.signal_buffer.lock().unwrap();
assert_eq!(buffer.len(), 4);
buffer[0].signal_id.clone()
};
assert!(client.is_open("g"));
let buffer = client.inner.signal_buffer.lock().unwrap();
assert_eq!(buffer.len(), 4);
assert!(buffer
.iter()
.all(|signal| signal.signal_id != oldest_signal_id));
assert!(buffer
.last()
.map(|signal| signal.dropped_signals_since_last >= 1)
.unwrap_or(false));
}
/// With a limit of 1 request/minute partitioned by `region`, each region gets its own
/// bucket. The second `us-east` check is limited, and `peek_is_open` agrees. The first
/// `eu-west` check is still allowed.
#[test]
fn rate_limit_applies_property_scoped_buckets_across_scopes() {
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![adopted_guard(
"rl",
true,
vec![],
Some(rate_limit(1, vec!["region"])),
)]);
let us_scope = client.create_scope(Properties::new().set("region", "us-east"));
let eu_scope = client.create_scope(Properties::new().set("region", "eu-west"));
assert!(us_scope.is_open("rl"));
assert!(!us_scope.is_open("rl"));
assert!(!us_scope.peek_is_open("rl"));
assert!(eu_scope.is_open("rl"));
}
/// Checks that evaluate to closed do not consume rate-limit quota. The limit is 1 with no
/// partition properties, so all scopes share one budget. Two closed `free` checks leave
/// that budget for the first `pro` check, and the second `pro` check is limited.
#[test]
fn closed_evaluations_do_not_consume_quota() {
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![adopted_guard(
"rl",
false,
vec![rule(
"plan",
Operator::Equals,
vec![PropertyValue::Text("pro".into())],
true,
)],
Some(rate_limit(1, vec![])),
)]);
let free_scope = client.create_scope(Properties::new().set("plan", "free"));
assert!(!free_scope.is_open("rl"));
assert!(!free_scope.is_open("rl"));
let pro_scope = client.create_scope(Properties::new().set("plan", "pro"));
assert!(pro_scope.is_open("rl"));
assert!(!pro_scope.is_open("rl"));
}
/// When a guard has both an active and a dry-run limit, a signal-emitting check records one
/// decision for each, active first.
///
/// On the second check both 1-request limits are exceeded:
/// * the active decision is `Limited`;
/// * the dry-run decision is `WouldLimit`;
/// * both report a window count of 2 plus the partition property names and values.
#[test]
fn rate_limit_records_active_and_dry_run_decisions() {
let (_runtime, client) = make_client();
client.set_guards_for_testing(vec![Guard {
name: "rl".into(),
rules: vec![],
default_value: true,
adopted: true,
rate_limit: Some(rate_limit(1, vec!["region"])),
dry_run_rate_limit: Some(dry_run_rate_limit("preview-1", 1, vec!["region"])),
disable_measurement: None,
}]);
let scope = client.create_scope(Properties::new().set("region", "us-east"));
assert!(scope.is_open("rl"));
assert!(!scope.is_open("rl"));
// Inspect the signal from the second (limited) check.
let buffer = client.inner.signal_buffer.lock().unwrap();
let signal = buffer.last().unwrap();
assert_eq!(signal.rate_limit_decisions.len(), 2);
let active_decision = &signal.rate_limit_decisions[0];
assert_eq!(
active_decision.evaluation_target,
RateLimitEvaluationTarget::Active
);
assert_eq!(active_decision.outcome, RateLimitDecisionOutcome::Limited);
assert_eq!(active_decision.requests_per_minute, 1);
assert_eq!(
active_decision.partition_properties,
vec!["region".to_string()]
);
assert_eq!(
active_decision.partition_values,
Properties::new().set("region", "us-east").0
);
assert_eq!(active_decision.count_in_window, 2);
let dry_run_decision = &signal.rate_limit_decisions[1];
assert_eq!(
dry_run_decision.evaluation_target,
RateLimitEvaluationTarget::DryRun
);
assert_eq!(dry_run_decision.dry_run_id.as_deref(), Some("preview-1"));
assert_eq!(
dry_run_decision.outcome,
RateLimitDecisionOutcome::WouldLimit
);
assert_eq!(dry_run_decision.requests_per_minute, 1);
assert_eq!(
dry_run_decision.partition_properties,
vec!["region".to_string()]
);
assert_eq!(
dry_run_decision.partition_values,
Properties::new().set("region", "us-east").0
);
assert_eq!(dry_run_decision.count_in_window, 2);
}
/// The configured environment is sent on every endpoint, both as the
/// `X-Liteguard-Environment` header and as the `environment` body field. The trailing `/`
/// in `backend_url` is not doubled in request paths.
///
/// Requests arrive in this order:
/// 1. the initial guard fetch;
/// 2. the signal flush (from `flush()`);
/// 3. the unadopted-guard flush (from `flush()`);
/// 4. the protected-context guard fetch.
///
/// The stub only serves these four, so clearing the protected context afterwards is
/// expected to be served from the bundle cache.
#[test]
fn environment_and_protected_context_are_sent_with_normalized_paths() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let address = listener.local_addr().unwrap();
let requests = Arc::new(Mutex::new(Vec::<String>::new()));
let requests_for_server = Arc::clone(&requests);
let server = thread::spawn(move || {
// Empty bodies are answered with 204, the others with 200.
for body in [
"{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag-1\"}",
"",
"",
"{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag-2\"}",
] {
let (mut stream, _) = listener.accept().unwrap();
let request = read_request(&mut stream);
requests_for_server.lock().unwrap().push(request);
write_response(&mut stream, if body.is_empty() { 204 } else { 200 }, body);
}
});
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
let client = runtime
.block_on(Client::init(
"test-key",
ClientOptions {
backend_url: format!("http://{address}/"),
environment: Some("production".into()),
refresh_rate_seconds: 3600,
flush_rate_seconds: 3600,
..Default::default()
},
))
.unwrap();
client.set_guards_for_testing(vec![
adopted_guard("open", true, vec![], None),
Guard {
name: "unadopted".into(),
rules: vec![],
default_value: false,
adopted: false,
rate_limit: None,
dry_run_rate_limit: None,
disable_measurement: None,
},
]);
assert!(client.is_open("open"));
assert!(client.is_open("unadopted"));
runtime.block_on(client.flush()).unwrap();
let scope = runtime
.block_on(
client
.create_scope(Properties::new())
.bind_protected_context(ProtectedContext {
properties: HashMap::from([(
String::from("email"),
String::from("alice@acme.com"),
)]),
signature: "sig-123".into(),
issued_at: None,
expires_at: None,
}),
)
.unwrap();
assert!(scope.protected_context().is_some());
runtime.block_on(scope.clear_protected_context()).unwrap();
drop(client);
server.join().unwrap();
// Header names are compared lowercased because the HTTP client may emit them in lowercase.
let requests = requests.lock().unwrap();
assert!(requests[0].starts_with("POST /api/v1/guards"));
assert!(requests[0]
.to_ascii_lowercase()
.contains("x-liteguard-environment: production"));
assert!(requests[0].contains("\"environment\":\"production\""));
assert!(requests[1].starts_with("POST /api/v1/signals"));
assert!(requests[1]
.to_ascii_lowercase()
.contains("x-liteguard-environment: production"));
assert!(requests[1].contains("\"environment\":\"production\""));
assert!(requests[1].contains("\"measurement\""));
assert!(requests[2].starts_with("POST /api/v1/unadopted-guards"));
assert!(requests[2]
.to_ascii_lowercase()
.contains("x-liteguard-environment: production"));
assert!(requests[2].contains("\"environment\":\"production\""));
assert!(requests[3].starts_with("POST /api/v1/guards"));
assert!(requests[3]
.to_ascii_lowercase()
.contains("x-liteguard-environment: production"));
assert!(requests[3].contains("\"protectedContext\""));
assert!(requests[3].contains("\"signature\":\"sig-123\""));
}
/// Checks that a signal flush rejected with an HTTP error status fails and keeps
/// the batch.
///
/// The failed flush must return `Err` and put the batch back into the signal
/// buffer. The following successful flush must deliver it and leave the buffer
/// empty.
#[test]
fn flush_requeues_on_http_status_error() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();
    // The scripted server answers exactly three connections, in order:
    // - initial guard fetch: 200
    // - first signal flush: 500
    // - retried flush: 204
    let server = thread::spawn(move || {
        for status in [200, 500, 204] {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = read_request(&mut stream);
            let body = if status == 200 {
                "{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag\"}"
            } else {
                ""
            };
            write_response(&mut stream, status, body);
        }
    });
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let client = runtime
        .block_on(Client::init(
            "test-key",
            ClientOptions {
                backend_url: format!("http://{address}"),
                // Pin background refresh/flush timers far in the future so no
                // timer-driven request can consume a scripted response or drain
                // the buffered signal before the explicit flushes below.
                refresh_rate_seconds: 3600,
                flush_rate_seconds: 3600,
                ..Default::default()
            },
        ))
        .unwrap();
    client.set_guards_for_testing(vec![adopted_guard("g", true, vec![], None)]);
    // Evaluating the guard records a signal for the next flush.
    assert!(client.is_open("g"));
    assert!(
        runtime.block_on(client.flush()).is_err(),
        "an HTTP 500 from the signals endpoint must surface as a flush error"
    );
    // The rejected batch is back in the buffer, awaiting a retry.
    assert_eq!(client.inner.signal_buffer.lock().unwrap().len(), 1);
    assert!(runtime.block_on(client.flush()).is_ok());
    assert_eq!(client.inner.signal_buffer.lock().unwrap().len(), 0);
    drop(client);
    server.join().unwrap();
}
/// Checks that requeuing a failed batch into an almost-full buffer drops the
/// oldest signals first.
///
/// The client is configured with `flush_size` 2 and `flush_buffer_multiplier` 2.
/// The buffer already holds `s3..s5` when the failed batch `s1, s2` is put back
/// ahead of them. Only the four newest signals may survive, in order. The single
/// discarded signal must be counted in `dropped_signals_pending`.
#[test]
fn failed_flush_requeue_keeps_newest_signals_when_buffer_overflows() {
    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = listener.local_addr().unwrap();
    // The scripted server answers exactly two connections, in order:
    // - initial guard fetch: 200
    // - direct batch flush: 500
    let server = thread::spawn(move || {
        for status in [200, 500] {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = read_request(&mut stream);
            let body = if status == 200 {
                "{\"guards\":[],\"refreshRateSeconds\":30,\"etag\":\"etag\"}"
            } else {
                ""
            };
            write_response(&mut stream, status, body);
        }
    });
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let client = runtime
        .block_on(Client::init(
            "test-key",
            ClientOptions {
                backend_url: format!("http://{address}"),
                flush_size: 2,
                flush_buffer_multiplier: 2,
                // Pin background refresh/flush timers far in the future. Otherwise a
                // timer-driven flush could drain the hand-built buffer, or consume
                // the scripted 500, between the buffer setup and the direct
                // `flush_signal_batch` call.
                refresh_rate_seconds: 3600,
                flush_rate_seconds: 3600,
                ..Default::default()
            },
        ))
        .unwrap();
    let make_signal = |id: &str| Signal {
        guard_name: "g".into(),
        result: true,
        properties: HashMap::new(),
        timestamp_ms: 0,
        trace: None,
        signal_id: id.into(),
        execution_id: "exec".into(),
        parent_signal_id: None,
        sequence_number: 0,
        callsite_id: "callsite".into(),
        kind: "guard_check".into(),
        dropped_signals_since_last: 0,
        measurement: None,
        rate_limit_decisions: vec![],
    };
    // Pre-fill the buffer. The scope releases the lock before the flush needs it.
    {
        let mut buffer = client.inner.signal_buffer.lock().unwrap();
        *buffer = vec![make_signal("s3"), make_signal("s4"), make_signal("s5")];
    }
    let result = runtime.block_on(Client::flush_signal_batch(
        &client.inner,
        vec![make_signal("s1"), make_signal("s2")],
    ));
    assert!(
        result.is_err(),
        "an HTTP 500 from the signals endpoint must surface as a flush error"
    );
    let ids = client
        .inner
        .signal_buffer
        .lock()
        .unwrap()
        .iter()
        .map(|signal| signal.signal_id.clone())
        .collect::<Vec<_>>();
    // s1, the oldest signal, is the one sacrificed to stay within capacity.
    assert_eq!(ids, vec!["s2", "s3", "s4", "s5"]);
    assert_eq!(*client.inner.dropped_signals_pending.lock().unwrap(), 1);
    drop(client);
    server.join().unwrap();
}
}