use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use std::time::Duration;
use reqwest::Client;
use serde_json::json;
use tokio::time::Instant;
use tracing::{debug, warn};
use crate::error::Error;
use crate::types::*;
/// Response header whose comma-separated value is parsed in `post` and
/// copied into the `hints` field of each result.
const HINT_HEADER: &str = "X-Erebyx-Hint";
/// Response header whose comma-separated value is parsed in `post` and
/// copied into the `auto_fired` field of each result.
const AUTO_FIRED_HEADER: &str = "X-Erebyx-Auto-Fired";
/// Request header that carries the client's session id on every POST.
const SESSION_ID_HEADER: &str = "X-Erebyx-Session-Id";
/// Splits a comma-separated header value into its trimmed, non-empty items.
///
/// Returns an empty vector when the header is absent, is not valid visible
/// ASCII (`to_str` fails), or contains no non-blank items.
fn parse_csv_header(value: Option<&reqwest::header::HeaderValue>) -> Vec<String> {
    value
        .and_then(|header| header.to_str().ok())
        .map(|text| {
            text.split(',')
                .map(str::trim)
                .filter(|item| !item.is_empty())
                .map(String::from)
                .collect::<Vec<String>>()
        })
        .unwrap_or_default()
}
/// Base URL used when the caller (or `EREBYX_API_URL`) does not supply one.
const DEFAULT_API_URL: &str = "https://core.erebyx.com";
/// Upper bound (10 MiB) that `post` compares against a response's declared
/// `Content-Length` before reading the body.
const MAX_RESPONSE_BYTES: u64 = 10 * 1024 * 1024;
/// Decides whether `url` may be used as the API base URL.
///
/// Any `https://` URL is accepted. Plain `http://` is accepted only when the
/// host is a loopback name: `localhost`, `127.0.0.1`, or the bracketed IPv6
/// literal `[::1]` (each optionally followed by `:port`). Everything else,
/// including other schemes, is rejected.
fn is_safe_url(url: &str) -> bool {
    if url.starts_with("https://") {
        return true;
    }
    let rest = match url.strip_prefix("http://") {
        Some(rest) => rest,
        None => return false,
    };
    // The authority component ends at the first path, query or fragment
    // delimiter, not only at '/'.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Reject userinfo outright: in `http://localhost:x@evil.example/` the text
    // before the first ':' is "localhost", but the real host is evil.example.
    if authority.contains('@') {
        return false;
    }
    // IPv6 literals are written in brackets, so the host cannot be found by
    // splitting on ':'; peel the brackets instead.
    if let Some(after_bracket) = authority.strip_prefix('[') {
        return match after_bracket.split_once(']') {
            Some((host, tail)) => host == "::1" && (tail.is_empty() || tail.starts_with(':')),
            None => false,
        };
    }
    let host = authority.split(':').next().unwrap_or("");
    matches!(host, "localhost" | "127.0.0.1")
}
/// Number of consecutive recorded failures at which a closed circuit opens.
const CIRCUIT_BREAK_THRESHOLD: u64 = 3;
/// Circuit state: calls are allowed through.
const CIRCUIT_STATE_CLOSED: u8 = 0;
/// Circuit state: calls are rejected until the cooldown has elapsed.
const CIRCUIT_STATE_OPEN: u8 = 1;
/// Circuit state: one probe call has been let through after the cooldown.
const CIRCUIT_STATE_HALF_OPEN: u8 = 2;
/// Default time an open circuit waits before allowing a probe call.
const CIRCUIT_COOLDOWN: Duration = Duration::from_secs(30);
/// Default overall request timeout handed to the HTTP client.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// Client handle for the memory API.
///
/// The handle only wraps an `Arc`, so clones share the same HTTP client,
/// credentials, session id and circuit-breaker state.
#[derive(Clone)]
pub struct Memory {
// Shared state behind every clone of this handle.
inner: Arc<MemoryInner>,
}
/// State shared by all clones of a [`Memory`] handle.
struct MemoryInner {
// HTTP client configured by `MemoryBuilder::build`.
client: Client,
// Base URL; `post` strips trailing '/' before appending the endpoint path.
api_url: String,
// Sent as a bearer token on every request.
api_key: String,
// Sent in the `X-Instance-ID` request header.
instance_id: String,
// Sent in the session-id request header and in the wrap-up request body.
session_id: String,
// When present, sent in the `X-Passphrase` request header.
passphrase: Option<String>,
// Failures recorded since the last success; compared with CIRCUIT_BREAK_THRESHOLD.
consecutive_failures: AtomicU64,
// One of the CIRCUIT_STATE_* values.
circuit_state: AtomicU8,
// Timestamp written when the circuit opens or re-opens; cleared when it closes.
circuit_opened_at: std::sync::Mutex<Option<Instant>>,
// How long an open circuit waits before a probe call is allowed.
circuit_cooldown: Duration,
}
impl Memory {
/// Creates a client for `api_key` with every other setting left at its
/// builder default.
///
/// # Errors
/// Returns whatever [`MemoryBuilder::build`] returns, e.g. `Error::Config`
/// for an empty key.
pub fn new(api_key: &str) -> Result<Self, Error> {
    let defaults = Self::builder(api_key);
    defaults.build()
}
pub fn from_env() -> Result<Self, Error> {
let api_key = std::env::var("EREBYX_API_KEY")
.map_err(|_| Error::Config("EREBYX_API_KEY not set".into()))?;
let api_url = std::env::var("EREBYX_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.into());
let instance_id = std::env::var("EREBYX_INSTANCE_ID").unwrap_or_else(|_| "default".into());
let passphrase = std::env::var("EREBYX_PASSPHRASE")
.ok()
.filter(|s| !s.trim().is_empty());
let mut builder = Self::builder(&api_key)
.api_url(&api_url)
.instance_id(&instance_id);
if let Some(p) = passphrase {
builder = builder.passphrase(&p);
}
builder.build()
}
/// Starts a [`MemoryBuilder`] for `api_key`.
///
/// Defaults: the default API URL, instance id `"default"`, a freshly
/// generated v4 UUID as session id, no passphrase, and the default timeout
/// and circuit cooldown.
pub fn builder(api_key: &str) -> MemoryBuilder {
    let fresh_session = uuid::Uuid::new_v4().to_string();
    MemoryBuilder {
        api_key: api_key.to_owned(),
        api_url: DEFAULT_API_URL.to_owned(),
        instance_id: String::from("default"),
        session_id: fresh_session,
        passphrase: None,
        timeout: DEFAULT_TIMEOUT,
        circuit_cooldown: CIRCUIT_COOLDOWN,
    }
}
/// Starts a save request for `content` in `category`.
///
/// Only constructs the builder; every optional field starts unset and
/// nothing is sent by this call.
pub fn save<'a>(&'a self, content: &str, category: &str) -> SaveBuilder<'a> {
    let content = String::from(content);
    let category = String::from(category);
    SaveBuilder {
        client: self,
        content,
        category,
        title: None,
        anchors: None,
        importance: None,
        memory_type: None,
    }
}
/// Starts a search request for `query`.
///
/// Only constructs the builder; every optional filter starts unset and
/// nothing is sent by this call.
pub fn search<'a>(&'a self, query: &str) -> SearchBuilder<'a> {
    let query = String::from(query);
    SearchBuilder {
        client: self,
        query,
        limit: None,
        hint_anchors: None,
        time_range: None,
        types: None,
    }
}
/// Starts a session wrap-up request from the two required summaries.
///
/// Only constructs the builder; every optional field starts unset and
/// nothing is sent by this call.
pub fn wrap_up<'a>(&'a self, what_we_built: &str, whats_next: &str) -> WrapUpBuilder<'a> {
    let what_we_built = String::from(what_we_built);
    let whats_next = String::from(whats_next);
    WrapUpBuilder {
        client: self,
        what_we_built,
        whats_next,
        diary: None,
        anchors: None,
        energy: None,
        memories: None,
    }
}
/// Starts an identity-restore request with every option unset.
///
/// Only constructs the builder; nothing is sent by this call.
pub fn restore_identity<'a>(&'a self) -> RestoreIdentityBuilder<'a> {
    let client = self;
    RestoreIdentityBuilder {
        client,
        limit: None,
        include_guide: None,
        detail_level: None,
    }
}
/// Starts a context-load request with every option unset.
///
/// Only constructs the builder; nothing is sent by this call.
pub fn load_context<'a>(&'a self) -> LoadContextBuilder<'a> {
    let client = self;
    LoadContextBuilder {
        client,
        anchors: None,
        mode: None,
        specialization_name: None,
        detail_level: None,
        load_priority: None,
    }
}
/// Sends a save request to `/v0/memory/store`.
///
/// The body always carries `content` and `category`; `title`, `anchors`,
/// `importance` and `type` are added only when set on the builder. Hints and
/// auto-fired names from the response headers are copied onto the result.
///
/// # Errors
/// Fails fast when the circuit breaker rejects the call, and otherwise
/// propagates errors from `post` and from decoding the response.
pub(crate) async fn execute_save(&self, builder: SaveBuilder<'_>) -> Result<SaveResult, Error> {
    self.check_circuit()?;

    let mut payload = serde_json::Map::new();
    payload.insert("content".to_owned(), json!(builder.content));
    payload.insert("category".to_owned(), json!(builder.category));
    if let Some(title) = &builder.title {
        payload.insert("title".to_owned(), json!(title));
    }
    if let Some(anchors) = &builder.anchors {
        payload.insert("anchors".to_owned(), json!(anchors));
    }
    if let Some(importance) = builder.importance {
        payload.insert("importance".to_owned(), json!(importance));
    }
    if let Some(memory_type) = &builder.memory_type {
        payload.insert("type".to_owned(), json!(memory_type));
    }
    let payload = serde_json::Value::Object(payload);

    let (raw, hints, auto_fired) = self.post("/v0/memory/store", &payload).await?;
    // The transport succeeded, so the breaker is reset before decoding.
    self.record_success();

    let mut saved = serde_json::from_value::<SaveResult>(raw)?;
    saved.hints = hints;
    saved.auto_fired = auto_fired;
    Ok(saved)
}
/// Sends a search request to `/v0/memory/remember`.
///
/// The body always carries `query`; `limit`, `hint_anchors`, `time_range`
/// and `types` are added only when set on the builder. Hints and auto-fired
/// names from the response headers are copied onto the result.
///
/// # Errors
/// Fails fast when the circuit breaker rejects the call, and otherwise
/// propagates errors from `post` and from decoding the response.
pub(crate) async fn execute_search(
    &self,
    builder: SearchBuilder<'_>,
) -> Result<SearchResult, Error> {
    self.check_circuit()?;

    let mut payload = serde_json::Map::new();
    payload.insert("query".to_owned(), json!(builder.query));
    if let Some(limit) = builder.limit {
        payload.insert("limit".to_owned(), json!(limit));
    }
    if let Some(anchors) = &builder.hint_anchors {
        payload.insert("hint_anchors".to_owned(), json!(anchors));
    }
    if let Some(range) = &builder.time_range {
        payload.insert("time_range".to_owned(), json!(range));
    }
    if let Some(types) = &builder.types {
        payload.insert("types".to_owned(), json!(types));
    }
    let payload = serde_json::Value::Object(payload);

    let (raw, hints, auto_fired) = self.post("/v0/memory/remember", &payload).await?;
    // The transport succeeded, so the breaker is reset before decoding.
    self.record_success();

    let mut found = serde_json::from_value::<SearchResult>(raw)?;
    found.hints = hints;
    found.auto_fired = auto_fired;
    Ok(found)
}
/// Sends a wrap-up request to `/v0/session/wrap-up`.
///
/// The body always carries the client's `session_id` plus `what_we_built`
/// and `whats_next`; `diary`, `anchors`, `energy` and `memories` are added
/// only when set on the builder. Hints and auto-fired names from the
/// response headers are copied onto the result.
///
/// # Errors
/// Fails fast when the circuit breaker rejects the call, and otherwise
/// propagates errors from `post` and from decoding the response.
pub(crate) async fn execute_wrap_up(
    &self,
    builder: WrapUpBuilder<'_>,
) -> Result<WrapUpResult, Error> {
    self.check_circuit()?;

    let mut payload = serde_json::Map::new();
    payload.insert("session_id".to_owned(), json!(&self.inner.session_id));
    payload.insert("what_we_built".to_owned(), json!(builder.what_we_built));
    payload.insert("whats_next".to_owned(), json!(builder.whats_next));
    if let Some(diary) = &builder.diary {
        payload.insert("diary".to_owned(), json!(diary));
    }
    if let Some(anchors) = &builder.anchors {
        payload.insert("anchors".to_owned(), json!(anchors));
    }
    if let Some(energy) = &builder.energy {
        payload.insert("energy".to_owned(), json!(energy));
    }
    if let Some(memories) = &builder.memories {
        payload.insert("memories".to_owned(), json!(memories));
    }
    let payload = serde_json::Value::Object(payload);

    let (raw, hints, auto_fired) = self.post("/v0/session/wrap-up", &payload).await?;
    // The transport succeeded, so the breaker is reset before decoding.
    self.record_success();

    let mut wrapped = serde_json::from_value::<WrapUpResult>(raw)?;
    wrapped.hints = hints;
    wrapped.auto_fired = auto_fired;
    Ok(wrapped)
}
/// Sends an identity-restore request to `/v0/identity/restore`.
///
/// The body starts as an empty object; `limit`, `include_guide` and
/// `detail_level` are added only when set on the builder. Hints and
/// auto-fired names from the response headers are copied onto the result.
///
/// # Errors
/// Fails fast when the circuit breaker rejects the call, and otherwise
/// propagates errors from `post` and from decoding the response.
pub(crate) async fn execute_restore_identity(
    &self,
    builder: RestoreIdentityBuilder<'_>,
) -> Result<RestoreIdentityResult, Error> {
    self.check_circuit()?;

    let mut payload = json!({});
    if let Some(limit) = builder.limit {
        payload["limit"] = json!(limit);
    }
    if let Some(include) = builder.include_guide {
        payload["include_guide"] = json!(include);
    }
    if let Some(level) = &builder.detail_level {
        payload["detail_level"] = json!(level);
    }

    let (raw, hints, auto_fired) = self.post("/v0/identity/restore", &payload).await?;
    // The transport succeeded, so the breaker is reset before decoding.
    self.record_success();

    let mut restored = serde_json::from_value::<RestoreIdentityResult>(raw)?;
    restored.hints = hints;
    restored.auto_fired = auto_fired;
    Ok(restored)
}
/// Sends a context-load request to `/v0/session/load`.
///
/// `anchors` is always present in the body (an empty array when unset);
/// `mode`, `specialization_name`, `detail_level` and `load_priority` are
/// added only when set on the builder. Hints and auto-fired names from the
/// response headers are copied onto the result.
///
/// # Errors
/// Fails fast when the circuit breaker rejects the call, and otherwise
/// propagates errors from `post` and from decoding the response.
pub(crate) async fn execute_load_context(
    &self,
    builder: LoadContextBuilder<'_>,
) -> Result<LoadContextResult, Error> {
    self.check_circuit()?;

    let anchors = match &builder.anchors {
        Some(list) => json!(list),
        None => json!(Vec::<String>::new()),
    };
    let mut payload = json!({ "anchors": anchors });
    if let Some(mode) = &builder.mode {
        payload["mode"] = json!(mode);
    }
    if let Some(spec) = &builder.specialization_name {
        payload["specialization_name"] = json!(spec);
    }
    if let Some(level) = &builder.detail_level {
        payload["detail_level"] = json!(level);
    }
    if let Some(priority) = &builder.load_priority {
        payload["load_priority"] = json!(priority);
    }

    let (raw, hints, auto_fired) = self.post("/v0/session/load", &payload).await?;
    // The transport succeeded, so the breaker is reset before decoding.
    self.record_success();

    let mut loaded = serde_json::from_value::<LoadContextResult>(raw)?;
    loaded.hints = hints;
    loaded.auto_fired = auto_fired;
    Ok(loaded)
}
async fn post(
&self,
path: &str,
body: &serde_json::Value,
) -> Result<(serde_json::Value, Vec<String>, Vec<String>), Error> {
let url = format!("{}{}", self.inner.api_url.trim_end_matches('/'), path);
debug!(url = %url, "erebyx-sdk POST");
let mut rb = self
.inner
.client
.post(&url)
.header("Content-Type", "application/json")
.bearer_auth(&self.inner.api_key)
.header("X-Instance-ID", &self.inner.instance_id)
.header(SESSION_ID_HEADER, &self.inner.session_id);
if let Some(ref p) = self.inner.passphrase {
rb = rb.header("X-Passphrase", p);
}
let response = rb.json(body).send().await.map_err(|e| {
self.record_failure();
Error::Network(e)
})?;
let status = response.status().as_u16();
if status == 401 || status == 403 {
return Err(Error::AuthenticationFailed(format!("HTTP {status}")));
}
if status == 429 {
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok())
.unwrap_or(10);
return Err(Error::RateLimit {
retry_after_secs: retry_after,
});
}
if status == 404 {
return Err(Error::NotFound(path.to_string()));
}
if (400..500).contains(&status) {
let text = match response.text().await {
Ok(t) => t,
Err(e) => {
self.record_failure();
return Err(Error::Network(e));
}
};
return Err(Error::Validation(text));
}
if status >= 500 {
self.record_failure();
let text = match response.text().await {
Ok(t) => t,
Err(e) => return Err(Error::Network(e)),
};
return Err(Error::Server {
status,
message: text,
});
}
if let Some(len) = response.content_length() {
if len > MAX_RESPONSE_BYTES {
return Err(Error::Server {
status,
message: format!(
"Response body too large ({} bytes; cap is {})",
len, MAX_RESPONSE_BYTES
),
});
}
}
let hints = parse_csv_header(response.headers().get(HINT_HEADER));
let auto_fired = parse_csv_header(response.headers().get(AUTO_FIRED_HEADER));
let json: serde_json::Value = response.json().await.map_err(|e| {
self.record_failure();
Error::Network(e)
})?;
Ok((json, hints, auto_fired))
}
fn check_circuit(&self) -> Result<(), Error> {
let state = self.inner.circuit_state.load(Ordering::Acquire);
match state {
CIRCUIT_STATE_CLOSED => Ok(()),
CIRCUIT_STATE_HALF_OPEN => {
Err(Error::CircuitOpen {
cooldown_secs: self.inner.circuit_cooldown.as_secs(),
})
}
CIRCUIT_STATE_OPEN => {
let opened_at = self
.inner
.circuit_opened_at
.lock()
.unwrap_or_else(|e| e.into_inner());
let elapsed = match *opened_at {
Some(t) => t.elapsed() >= self.inner.circuit_cooldown,
None => true,
};
drop(opened_at);
if elapsed {
if self
.inner
.circuit_state
.compare_exchange(
CIRCUIT_STATE_OPEN,
CIRCUIT_STATE_HALF_OPEN,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
debug!("erebyx-sdk circuit HALF_OPEN — probe call allowed");
return Ok(());
}
}
Err(Error::CircuitOpen {
cooldown_secs: self.inner.circuit_cooldown.as_secs(),
})
}
_ => Ok(()),
}
}
/// Records a successful call: resets the failure counter and closes the
/// circuit.
///
/// When the circuit was not already closed, the stored open timestamp is
/// cleared and the transition is logged at debug level.
fn record_success(&self) {
    let inner = &self.inner;
    inner.consecutive_failures.store(0, Ordering::Relaxed);

    let previous = inner
        .circuit_state
        .swap(CIRCUIT_STATE_CLOSED, Ordering::AcqRel);
    if previous == CIRCUIT_STATE_CLOSED {
        return;
    }

    // A poisoned lock is recovered rather than propagated.
    let mut opened_at = inner
        .circuit_opened_at
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    opened_at.take();
    debug!(
        prev_state = previous,
        "erebyx-sdk circuit CLOSED — substrate healthy again"
    );
}
/// Records a failed call and updates the circuit state.
///
/// * HALF_OPEN: the probe failed, so the circuit re-opens and the cooldown
///   timestamp restarts.
/// * OPEN: nothing to do.
/// * otherwise: the consecutive-failure counter is incremented; once it
///   reaches `CIRCUIT_BREAK_THRESHOLD` the circuit is opened (by whichever
///   caller wins the compare-exchange) and the timestamp is set.
fn record_failure(&self) {
    let inner = &self.inner;
    match inner.circuit_state.load(Ordering::Acquire) {
        CIRCUIT_STATE_HALF_OPEN => {
            inner
                .circuit_state
                .store(CIRCUIT_STATE_OPEN, Ordering::Release);
            let mut opened_at = inner
                .circuit_opened_at
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            *opened_at = Some(Instant::now());
            warn!("erebyx-sdk circuit RE-OPENED — half-open probe failed");
            return;
        }
        CIRCUIT_STATE_OPEN => return,
        _ => {}
    }

    let total = inner.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
    if total < CIRCUIT_BREAK_THRESHOLD {
        return;
    }

    let tripped = inner.circuit_state.compare_exchange(
        CIRCUIT_STATE_CLOSED,
        CIRCUIT_STATE_OPEN,
        Ordering::AcqRel,
        Ordering::Acquire,
    );
    if tripped.is_ok() {
        let mut opened_at = inner
            .circuit_opened_at
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *opened_at = Some(Instant::now());
        warn!(
            failures = total,
            "erebyx-sdk circuit OPEN — {CIRCUIT_BREAK_THRESHOLD} consecutive failures"
        );
    }
}
}
/// Builder for [`Memory`]; created by `Memory::builder` and finished with
/// `build`.
pub struct MemoryBuilder {
// Required; `build` rejects an empty key.
api_key: String,
// Base URL; `build` rejects it unless `is_safe_url` accepts it.
api_url: String,
// Instance identifier copied into the client.
instance_id: String,
// Session identifier copied into the client.
session_id: String,
// Optional passphrase; the setter stores it trimmed and ignores blank input.
passphrase: Option<String>,
// Overall request timeout handed to the HTTP client.
timeout: Duration,
// Cooldown copied into the client's circuit breaker.
circuit_cooldown: Duration,
}
impl MemoryBuilder {
pub fn api_url(mut self, url: &str) -> Self {
self.api_url = url.to_string();
self
}
pub fn instance_id(mut self, id: &str) -> Self {
self.instance_id = id.to_string();
self
}
pub fn passphrase(mut self, passphrase: &str) -> Self {
let trimmed = passphrase.trim();
if !trimmed.is_empty() {
self.passphrase = Some(trimmed.to_string());
}
self
}
pub fn session_id(mut self, id: &str) -> Self {
if !id.is_empty() {
self.session_id = id.to_string();
}
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn circuit_cooldown(mut self, cooldown: Duration) -> Self {
self.circuit_cooldown = cooldown;
self
}
pub fn build(self) -> Result<Memory, Error> {
if self.api_key.is_empty() {
return Err(Error::Config("API key cannot be empty".into()));
}
if !is_safe_url(&self.api_url) {
return Err(Error::Config(format!(
"api_url must be https:// (got {}). \
Plain http:// is only allowed for localhost/127.0.0.1.",
self.api_url
)));
}
let client = Client::builder()
.timeout(self.timeout)
.connect_timeout(Duration::from_secs(5))
.build()
.map_err(|e| Error::Config(format!("Failed to create HTTP client: {e}")))?;
Ok(Memory {
inner: Arc::new(MemoryInner {
client,
api_url: self.api_url,
api_key: self.api_key,
instance_id: self.instance_id,
session_id: self.session_id,
passphrase: self.passphrase,
consecutive_failures: AtomicU64::new(0),
circuit_state: AtomicU8::new(CIRCUIT_STATE_CLOSED),
circuit_opened_at: std::sync::Mutex::new(None),
circuit_cooldown: self.circuit_cooldown,
}),
})
}
}