Skip to main content

ironflow_api/
state.rs

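//! Application state and dependency injection.
//!
//! [`AppState`] bundles the shared [`Store`] and [`Engine`] together with the
//! JWT configuration, the worker authentication token, and the SSE event
//! sender, so that every handler can extract what it needs from Axum state.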
4
5use std::sync::Arc;
6#[cfg(feature = "prometheus")]
7use std::sync::OnceLock;
8
9use axum::extract::FromRef;
10#[cfg(feature = "prometheus")]
11use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
12use tokio::sync::broadcast;
13use uuid::Uuid;
14
15use ironflow_auth::jwt::JwtConfig;
16use ironflow_engine::engine::Engine;
17use ironflow_engine::notify::Event;
18use ironflow_store::entities::Run;
19use ironflow_store::store::Store;
20
21use crate::error::ApiError;
22
23/// Global application state.
24///
25/// Holds the shared store (runs, users, API keys, secrets) and engine,
26/// extracted by handlers using Axum's state extraction mechanism.
27///
28/// # Examples
29///
30/// ```no_run
31/// use ironflow_api::state::AppState;
32/// use ironflow_auth::jwt::JwtConfig;
33/// use ironflow_store::prelude::*;
34/// use ironflow_store::store::Store;
35/// use ironflow_engine::engine::Engine;
36/// use ironflow_core::providers::claude::ClaudeCodeProvider;
37/// use std::sync::Arc;
38///
39/// # async fn example() {
40/// let store: Arc<dyn Store> = Arc::new(InMemoryStore::new());
41/// let provider = Arc::new(ClaudeCodeProvider::new());
42/// let engine = Arc::new(Engine::new(store.clone(), provider));
43/// let jwt_config = Arc::new(JwtConfig {
44///     secret: "secret".to_string(),
45///     access_token_ttl_secs: 900,
46///     refresh_token_ttl_secs: 604800,
47///     cookie_domain: None,
48///     cookie_secure: false,
49/// });
50/// let broadcaster = ironflow_api::sse::SseBroadcaster::new();
51/// let state = AppState::new(store, engine, jwt_config, "token".to_string(), broadcaster.sender());
52/// # }
53/// ```
54#[derive(Clone)]
55pub struct AppState {
56    /// The unified backing store for runs, steps, users, API keys, and secrets.
57    pub store: Arc<dyn Store>,
58    /// The workflow orchestration engine.
59    pub engine: Arc<Engine>,
60    /// JWT configuration for auth tokens.
61    pub jwt_config: Arc<JwtConfig>,
62    /// Static token for worker-to-API authentication.
63    pub worker_token: String,
64    /// Broadcast sender for SSE event streaming.
65    pub event_sender: broadcast::Sender<Event>,
66    /// Prometheus metrics handle (only when `prometheus` feature is enabled).
67    #[cfg(feature = "prometheus")]
68    pub prometheus_handle: PrometheusHandle,
69}
70
71impl FromRef<AppState> for Arc<dyn Store> {
72    fn from_ref(state: &AppState) -> Self {
73        Arc::clone(&state.store)
74    }
75}
76
77impl FromRef<AppState> for Arc<JwtConfig> {
78    fn from_ref(state: &AppState) -> Self {
79        Arc::clone(&state.jwt_config)
80    }
81}
82
#[cfg(feature = "prometheus")]
impl FromRef<AppState> for PrometheusHandle {
    /// Return a clone of the Prometheus handle stored in the state.
    fn from_ref(state: &AppState) -> Self {
        let AppState { prometheus_handle, .. } = state;
        PrometheusHandle::clone(prometheus_handle)
    }
}
89
90impl AppState {
91    /// Create a new `AppState`.
92    ///
93    /// When the `prometheus` feature is enabled, a global Prometheus recorder
94    /// is installed (once) and its handle is stored in the state.
95    ///
96    /// # Panics
97    ///
98    /// Panics if a Prometheus recorder cannot be installed (should only
99    /// happen if another incompatible recorder was set elsewhere).
100    pub fn new(
101        store: Arc<dyn Store>,
102        engine: Arc<Engine>,
103        jwt_config: Arc<JwtConfig>,
104        worker_token: String,
105        event_sender: broadcast::Sender<Event>,
106    ) -> Self {
107        Self {
108            store,
109            engine,
110            jwt_config,
111            worker_token,
112            event_sender,
113            #[cfg(feature = "prometheus")]
114            prometheus_handle: Self::global_prometheus_handle(),
115        }
116    }
117
118    /// Install (or reuse) a global Prometheus recorder and return its handle.
119    #[cfg(feature = "prometheus")]
120    fn global_prometheus_handle() -> PrometheusHandle {
121        static HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
122        HANDLE
123            .get_or_init(|| {
124                PrometheusBuilder::new()
125                    .install_recorder()
126                    .expect("failed to install Prometheus recorder")
127            })
128            .clone()
129    }
130
131    /// Fetch a run by ID or return 404.
132    ///
133    /// # Errors
134    ///
135    /// Returns `ApiError::RunNotFound` if the run does not exist.
136    /// Returns `ApiError::Store` if there is a store error.
137    pub async fn get_run_or_404(&self, id: Uuid) -> Result<Run, ApiError> {
138        self.store
139            .get_run(id)
140            .await
141            .map_err(ApiError::from)?
142            .ok_or(ApiError::RunNotFound(id))
143    }
144}
145
#[cfg(test)]
mod tests {
    // `super::*` already brings in `Arc`, `Store`, `Engine`, `JwtConfig`,
    // `broadcast`, and `Event` from the parent module's imports.
    use super::*;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_store::memory::InMemoryStore;

    /// Build an `AppState` backed by an in-memory store for unit tests.
    fn test_state() -> AppState {
        let store: Arc<dyn Store> = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let (event_sender, _) = broadcast::channel::<Event>(1);
        AppState::new(
            store,
            engine,
            jwt_config,
            "test-worker-token".to_string(),
            event_sender,
        )
    }

    #[test]
    fn app_state_cloneable() {
        let state = test_state();
        let cloned = state.clone();
        // A clone must share the underlying dependencies, not duplicate them.
        assert!(Arc::ptr_eq(&cloned.store, &state.store));
        assert!(Arc::ptr_eq(&cloned.engine, &state.engine));
        assert!(Arc::ptr_eq(&cloned.jwt_config, &state.jwt_config));
        assert_eq!(cloned.worker_token, state.worker_token);
    }

    #[test]
    fn app_state_from_ref() {
        let state = test_state();
        let extracted: Arc<dyn Store> = Arc::from_ref(&state);
        assert!(Arc::ptr_eq(&extracted, &state.store));
    }

    #[test]
    fn jwt_config_from_ref() {
        let state = test_state();
        let extracted: Arc<JwtConfig> = Arc::from_ref(&state);
        assert!(Arc::ptr_eq(&extracted, &state.jwt_config));
    }

    #[test]
    fn app_state_new_can_be_called_repeatedly() {
        // With the `prometheus` feature, the second construction must reuse the
        // global recorder rather than panic trying to install another one.
        let _first = test_state();
        let _second = test_state();
    }
}