1use std::collections::HashMap;
4use std::net::IpAddr;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::{Instant, SystemTime};
8
9use tokio::sync::{broadcast, Mutex, RwLock};
10use tokio_util::sync::CancellationToken;
11use tuitbot_core::automation::circuit_breaker::CircuitBreaker;
12use tuitbot_core::automation::Runtime;
13use tuitbot_core::automation::WatchtowerLoop;
14use tuitbot_core::config::{
15 effective_config, Config, ConnectorConfig, ContentSourcesConfig, DeploymentMode,
16};
17use tuitbot_core::content::ContentGenerator;
18use tuitbot_core::llm::factory::create_provider;
19use tuitbot_core::storage::accounts::{self, DEFAULT_ACCOUNT_ID};
20use tuitbot_core::storage::DbPool;
21use tuitbot_core::x_api::auth::TokenManager;
22
23use tuitbot_core::error::XApiError;
24use tuitbot_core::x_api::auth;
25
26use crate::ws::AccountWsEvent;
27
/// State kept for an OAuth authorization flow that has been started but not yet completed.
pub struct PendingOAuth {
    /// Code verifier generated for this authorization request
    /// (presumably the PKCE verifier — confirm against the OAuth handler).
    pub code_verifier: String,
    /// Moment the flow was started; presumably used to expire stale entries.
    pub created_at: Instant,
    /// Account this flow belongs to.
    pub account_id: String,
    /// OAuth client id the flow was started with.
    pub client_id: String,
}
39
40pub struct AppState {
42 pub db: DbPool,
44 pub config_path: PathBuf,
46 pub data_dir: PathBuf,
48 pub event_tx: broadcast::Sender<AccountWsEvent>,
50 pub api_token: String,
52 pub passphrase_hash: RwLock<Option<String>>,
54 pub passphrase_hash_mtime: RwLock<Option<SystemTime>>,
56 pub bind_host: String,
58 pub bind_port: u16,
60 pub login_attempts: Mutex<HashMap<IpAddr, (u32, Instant)>>,
62 pub runtimes: Mutex<HashMap<String, Runtime>>,
64 pub content_generators: Mutex<HashMap<String, Arc<ContentGenerator>>>,
66 pub circuit_breaker: Option<Arc<CircuitBreaker>>,
68 pub watchtower_cancel: RwLock<Option<CancellationToken>>,
70 pub content_sources: RwLock<ContentSourcesConfig>,
72 pub connector_config: ConnectorConfig,
74 pub deployment_mode: DeploymentMode,
76 pub pending_oauth: Mutex<HashMap<String, PendingOAuth>>,
78 pub token_managers: Mutex<HashMap<String, Arc<TokenManager>>>,
80 pub x_client_id: String,
82}
83
84impl AppState {
85 pub async fn get_x_access_token(
90 &self,
91 token_path: &std::path::Path,
92 account_id: &str,
93 ) -> Result<String, XApiError> {
94 {
96 let managers = self.token_managers.lock().await;
97 if let Some(tm) = managers.get(account_id) {
98 return tm.get_access_token().await;
99 }
100 }
101
102 let tokens = auth::load_tokens(token_path)?.ok_or(XApiError::AuthExpired)?;
104
105 let tm = Arc::new(TokenManager::new(
106 tokens,
107 self.x_client_id.clone(),
108 token_path.to_path_buf(),
109 ));
110
111 let access_token = tm.get_access_token().await?;
112
113 self.token_managers
114 .lock()
115 .await
116 .insert(account_id.to_string(), tm);
117
118 Ok(access_token)
119 }
120
121 pub async fn load_effective_config(&self, account_id: &str) -> Result<Config, String> {
126 let contents = std::fs::read_to_string(&self.config_path).unwrap_or_default();
127 let base: Config = toml::from_str(&contents).unwrap_or_default();
128
129 if account_id == DEFAULT_ACCOUNT_ID {
130 return Ok(base);
131 }
132
133 let account = accounts::get_account(&self.db, account_id)
134 .await
135 .map_err(|e| e.to_string())?
136 .ok_or_else(|| format!("account not found: {account_id}"))?;
137
138 effective_config(&base, &account.config_overrides)
139 .map(|r| r.config)
140 .map_err(|e| e.to_string())
141 }
142
143 pub async fn get_or_create_content_generator(
147 &self,
148 account_id: &str,
149 ) -> Result<Arc<ContentGenerator>, String> {
150 {
152 let generators = self.content_generators.lock().await;
153 if let Some(gen) = generators.get(account_id) {
154 return Ok(gen.clone());
155 }
156 }
157
158 let config = self.load_effective_config(account_id).await?;
159
160 let provider =
161 create_provider(&config.llm).map_err(|e| format!("LLM not configured: {e}"))?;
162
163 let gen = Arc::new(ContentGenerator::new(provider, config.business));
164
165 self.content_generators
166 .lock()
167 .await
168 .insert(account_id.to_string(), gen.clone());
169
170 Ok(gen)
171 }
172
173 pub async fn restart_watchtower(&self) {
179 if let Some(cancel) = self.watchtower_cancel.write().await.take() {
181 cancel.cancel();
182 tracing::info!("Watchtower cancelled for config reload");
183 }
184
185 let loaded_config = Config::load(Some(&self.config_path.to_string_lossy())).ok();
187 let new_sources = loaded_config
188 .as_ref()
189 .map(|c| c.content_sources.clone())
190 .unwrap_or_default();
191 let connector_config = loaded_config
192 .as_ref()
193 .map(|c| c.connectors.clone())
194 .unwrap_or_default();
195 let deployment_mode = loaded_config
196 .as_ref()
197 .map(|c| c.deployment_mode.clone())
198 .unwrap_or_default();
199
200 let has_enabled: Vec<_> = new_sources
202 .sources
203 .iter()
204 .filter(|s| {
205 s.is_enabled()
206 && deployment_mode.allows_source_type(&s.source_type)
207 && (s.path.is_some() || s.folder_id.is_some())
208 })
209 .collect();
210
211 if has_enabled.is_empty() {
212 tracing::info!("Watchtower restart: no enabled sources, not spawning");
213 *self.content_sources.write().await = new_sources;
214 return;
215 }
216
217 let cancel = CancellationToken::new();
219 let watchtower = WatchtowerLoop::new(
220 self.db.clone(),
221 new_sources.clone(),
222 connector_config,
223 self.data_dir.clone(),
224 );
225 let cancel_clone = cancel.clone();
226 tokio::spawn(async move {
227 watchtower.run(cancel_clone).await;
228 });
229
230 tracing::info!(
231 sources = has_enabled.len(),
232 "Watchtower restarted with updated config"
233 );
234
235 *self.watchtower_cancel.write().await = Some(cancel);
237 *self.content_sources.write().await = new_sources;
238 }
239}