1use std::{
9 collections::HashMap,
10 net::{IpAddr, SocketAddr},
11 sync::{Arc, Mutex},
12 time::{Duration, Instant},
13};
14
15use anyhow::Context as _;
16use axum::{
17 Json, Router,
18 extract::{
19 ConnectInfo, Extension, FromRequestParts, Path, Query, WebSocketUpgrade,
20 ws::{Message, WebSocket},
21 },
22 http::{HeaderMap, StatusCode, request::Parts},
23 response::IntoResponse,
24 routing::{get, post},
25};
26use serde::{Deserialize, Serialize};
27use tracing::{debug, error, warn};
28use uuid::Uuid;
29
30use super::{
31 WebChatProvider,
32 auth::{self as jwt, Claims, TenantClaims},
33 bus::{NoopBus, SharedBus},
34 conversation::{
35 Activity, ChannelAccount, ConversationAccount, SharedConversationStore, StoredActivity,
36 memory_store,
37 },
38 ingress,
39 session::{MemorySessionStore, SharedSessionStore, WebchatSession},
40};
41use greentic_types::{EnvId, TeamId, TenantCtx, TenantId};
42
/// Lifetime, in seconds, of every token minted here: both the `tokens/generate` token and
/// the conversation-bound token issued when a conversation starts.
const TOKEN_TTL_SECONDS: u64 = 1_800;
/// Maximum burst of `tokens/generate` requests accepted per client IP.
const RATE_LIMIT_CAPACITY: usize = 5;
/// Period over which an empty per-IP bucket refills back to `RATE_LIMIT_CAPACITY`.
const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(60);
46
47struct RemoteIp(Option<IpAddr>);
48
49impl<S> FromRequestParts<S> for RemoteIp
50where
51 S: Send + Sync,
52{
53 type Rejection = std::convert::Infallible;
54
55 fn from_request_parts(
56 parts: &mut Parts,
57 _state: &S,
58 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
59 let ip = parts
60 .extensions
61 .get::<ConnectInfo<SocketAddr>>()
62 .map(|ConnectInfo(addr)| addr.ip());
63 std::future::ready(Ok(RemoteIp(ip)))
64 }
65}
66
/// Shared state for the standalone Direct Line–compatible HTTP server.
///
/// Handlers receive it as `Extension<Arc<StandaloneState>>`; see [`router`].
#[derive(Clone)]
pub struct StandaloneState {
    /// Web chat provider; supplies the signing keys and the configuration.
    pub provider: WebChatProvider,
    /// Store of conversations and their activities.
    pub conversations: SharedConversationStore,
    /// Store of per-conversation webchat sessions.
    pub sessions: SharedSessionStore,
    /// Bus onto which incoming user activities are published.
    pub bus: SharedBus,
    /// Per-IP limiter applied to `tokens/generate`.
    rate_limiter: Arc<IpRateLimiter>,
}
76
77impl StandaloneState {
78 pub async fn with_store(
80 provider: WebChatProvider,
81 conversations: SharedConversationStore,
82 sessions: SharedSessionStore,
83 bus: SharedBus,
84 ) -> anyhow::Result<Self> {
85 let signing_keys = provider
86 .signing_keys()
87 .await
88 .context("failed to resolve Direct Line signing key")?;
89 if let Err(err) = jwt::install_keys(signing_keys) {
90 debug!("jwt keys already installed: {err}");
91 }
92 Ok(Self {
93 provider,
94 conversations,
95 sessions,
96 bus,
97 rate_limiter: Arc::new(IpRateLimiter::new(RATE_LIMIT_CAPACITY, RATE_LIMIT_WINDOW)),
98 })
99 }
100
101 pub async fn new(provider: WebChatProvider) -> anyhow::Result<Self> {
103 Self::with_store(
104 provider,
105 memory_store(),
106 Arc::new(MemorySessionStore::default()),
107 Arc::new(NoopBus),
108 )
109 .await
110 }
111}
112
113fn router_blueprint() -> Router {
114 Router::new()
115 .route(
116 "/v3/directline/tokens/generate",
117 post(generate_token_handler),
118 )
119 .route(
120 "/v3/directline/conversations",
121 post(start_conversation_handler),
122 )
123 .route(
124 "/v3/directline/conversations/{id}/activities",
125 get(list_activities_handler).post(post_activity_handler),
126 )
127 .route(
128 "/v3/directline/conversations/{id}/stream",
129 get(conversation_stream_handler),
130 )
131 .route(
132 "/webchat/admin/{env}/{tenant}/post-activity",
133 post(admin_post_activity_handler),
134 )
135}
136
137pub fn router(state: Arc<StandaloneState>) -> Router {
139 router_blueprint().layer(axum::Extension(state))
140}
141
/// Compile-time helper: naming `assert_state_bounds::<T>()` fails to compile unless `T`
/// is `Send + Sync + Clone`.
// Only referenced from the const assertion below, hence the dead_code allowance.
#[allow(dead_code)]
const fn assert_state_bounds<T: Send + Sync + Clone>() {}
// Static check that `StandaloneState` can be cloned into and shared between handlers.
const _: () = {
    assert_state_bounds::<StandaloneState>();
};
147
/// Query parameters identifying the tenant for `tokens/generate`.
#[doc(hidden)]
#[derive(Debug, Deserialize)]
pub struct TenantQuery {
    /// Environment id; rejected by the handler when blank after trimming.
    env: String,
    /// Tenant id; rejected by the handler when blank after trimming.
    tenant: String,
    /// Optional team id; ignored when blank after trimming.
    #[serde(default)]
    team: Option<String>,
}
156
/// JSON body accepted by `tokens/generate` (field names are camelCase on the wire).
#[doc(hidden)]
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateTokenRequest {
    /// Optional user descriptor; a non-blank `user.id` becomes the token subject.
    #[serde(default)]
    user: Option<UserDescriptor>,
    /// `trustedOrigins` is accepted for compatibility but not read by this module.
    #[allow(dead_code)]
    #[serde(default)]
    trusted_origins: Option<Vec<String>>,
}
167
/// `user` object of a `GenerateTokenRequest`.
#[doc(hidden)]
#[derive(Debug, Deserialize)]
pub struct UserDescriptor {
    /// Requested user id; used as the token subject when non-blank.
    #[serde(default)]
    id: Option<String>,
}
174
/// Response of `tokens/generate`: the signed token and its lifetime.
#[derive(Debug, Serialize, Deserialize)]
struct TokenResponse {
    /// Signed JWT.
    token: String,
    /// Lifetime in seconds; serialised as `expires_in` (no rename applied).
    expires_in: u64,
}
180
181async fn generate_token_handler(
182 Extension(state): Extension<Arc<StandaloneState>>,
183 remote: RemoteIp,
184 Query(query): Query<TenantQuery>,
185 Json(body): Json<GenerateTokenRequest>,
186) -> Result<Json<TokenResponse>, StatusCode> {
187 let ip = remote.0.unwrap_or_else(|| IpAddr::from([127, 0, 0, 1]));
188 if !state.rate_limiter.check(ip) {
189 return Err(StatusCode::TOO_MANY_REQUESTS);
190 }
191
192 let tenant_ctx = tenant_ctx_from_query(&query)?;
193 let subject = body
194 .user
195 .and_then(|user| user.id)
196 .filter(|id| !id.trim().is_empty())
197 .unwrap_or_else(|| format!("user:{}", Uuid::new_v4()));
198
199 let claims = Claims::new(
200 subject,
201 tenant_claims_from_ctx(&tenant_ctx),
202 jwt::ttl(TOKEN_TTL_SECONDS),
203 );
204 let token = jwt::sign(&claims).map_err(|err| map_error("sign token", err))?;
205 Ok(Json(TokenResponse {
206 token,
207 expires_in: TOKEN_TTL_SECONDS,
208 }))
209}
210
/// `{id}` path segment of the conversation routes.
#[doc(hidden)]
#[derive(Debug, Deserialize)]
pub struct ConversationPath {
    /// Conversation id.
    id: String,
}
216
/// Response of `POST /v3/directline/conversations`, serialised in camelCase
/// (`conversationId`, `token`, `streamUrl`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConversationResponse {
    conversation_id: String,
    /// Token bound to this conversation.
    token: String,
    /// Websocket stream URL; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_url: Option<String>,
}
225
226async fn start_conversation_handler(
227 Extension(state): Extension<Arc<StandaloneState>>,
228 headers: HeaderMap,
229) -> Result<Json<ConversationResponse>, StatusCode> {
230 let token = extract_bearer(&headers)?;
231 let claims = jwt::verify(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
232 if claims.conv.is_some() {
233 return Err(StatusCode::BAD_REQUEST);
234 }
235 let tenant_ctx = tenant_ctx_from_claims(&claims)?;
236 let conversation_id = Uuid::new_v4().to_string();
237 state
238 .conversations
239 .create(&conversation_id, tenant_ctx.clone())
240 .await
241 .map_err(map_store_error)?;
242
243 let conversation_token = Claims::new(
244 claims.sub.clone(),
245 claims.ctx.clone(),
246 jwt::ttl(TOKEN_TTL_SECONDS),
247 )
248 .with_conversation(conversation_id.clone());
249 let token = jwt::sign(&conversation_token).map_err(|err| map_error("sign token", err))?;
250
251 if let Err(err) = state
252 .sessions
253 .upsert(WebchatSession::new(
254 conversation_id.clone(),
255 tenant_ctx.clone(),
256 token.clone(),
257 ))
258 .await
259 {
260 error!(error = %err, "failed to persist webchat session");
261 return Err(StatusCode::INTERNAL_SERVER_ERROR);
262 }
263
264 let stream_url = stream_url_from_headers(&state, &headers, &conversation_id, &token);
265 Ok(Json(ConversationResponse {
266 conversation_id,
267 token,
268 stream_url,
269 }))
270}
271
/// Query string of `GET .../activities`.
#[doc(hidden)]
#[derive(Debug, Deserialize)]
pub struct ActivitiesQuery {
    /// Optional decimal watermark; the handler treats a blank value as absent.
    #[serde(default)]
    watermark: Option<String>,
}
278
/// Response of `GET .../activities`: a page of activities plus the next watermark.
#[derive(Debug, Serialize)]
struct ActivitiesResponse {
    activities: Vec<Activity>,
    /// Watermark reported by the store for this page, rendered as a decimal string.
    watermark: String,
}
284
285async fn list_activities_handler(
286 Extension(state): Extension<Arc<StandaloneState>>,
287 Path(path): Path<ConversationPath>,
288 headers: HeaderMap,
289 Query(query): Query<ActivitiesQuery>,
290) -> Result<Json<ActivitiesResponse>, StatusCode> {
291 let _claims = validate_conversation_token(state.as_ref(), &path.id, &headers).await?;
292 let watermark = query
293 .watermark
294 .as_deref()
295 .and_then(|value| {
296 let trimmed = value.trim();
297 if trimmed.is_empty() {
298 None
299 } else {
300 Some(trimmed)
301 }
302 })
303 .map(parse_watermark)
304 .transpose()?;
305 let page = state
306 .conversations
307 .activities(&path.id, watermark)
308 .await
309 .map_err(map_store_error)?;
310 let activities: Vec<Activity> = page
311 .activities
312 .into_iter()
313 .map(|entry| entry.activity)
314 .collect();
315 Ok(Json(ActivitiesResponse {
316 activities,
317 watermark: page.watermark.to_string(),
318 }))
319}
320
/// Acknowledgement returned after posting an activity.
#[derive(Debug, Serialize)]
struct ActivityAck {
    /// Id of the activity as stored.
    id: String,
}
325
326async fn post_activity_handler(
327 Extension(state): Extension<Arc<StandaloneState>>,
328 Path(path): Path<ConversationPath>,
329 headers: HeaderMap,
330 Json(payload): Json<serde_json::Value>,
331) -> Result<impl IntoResponse, StatusCode> {
332 let claims = validate_conversation_token(state.as_ref(), &path.id, &headers).await?;
333 let tenant_ctx = state
334 .conversations
335 .tenant_ctx(&path.id)
336 .await
337 .map_err(map_store_error)?;
338
339 let mut bus_activity = payload.clone();
340 apply_user_json_defaults(&mut bus_activity, &path.id, &claims.sub);
341
342 let mut activity: Activity =
343 serde_json::from_value(payload).map_err(|_| StatusCode::BAD_REQUEST)?;
344 apply_user_defaults(&mut activity, &path.id, &claims.sub);
345
346 let ingress = ingress::Ingress::new(state.bus.clone(), state.sessions.clone());
347 if let Err(err) = ingress
348 .publish_incoming(&bus_activity, &tenant_ctx, &path.id)
349 .await
350 {
351 warn!(error = %err, "failed to publish incoming activity");
352 }
353
354 let stored = state
355 .conversations
356 .append(&path.id, activity)
357 .await
358 .map_err(map_store_error)?;
359
360 if let Err(err) = state
361 .sessions
362 .update_watermark(&path.id, Some((stored.watermark + 1).to_string()))
363 .await
364 {
365 warn!(error = %err, "failed to update watermark");
366 }
367 Ok((
368 StatusCode::CREATED,
369 Json(ActivityAck {
370 id: stored.activity.id,
371 }),
372 ))
373}
374
/// `{env}/{tenant}` path segments of the admin post-activity route.
#[derive(Debug, Deserialize)]
pub struct AdminPath {
    pub env: String,
    pub tenant: String,
}
380
/// Body of the admin post-activity endpoint.
#[derive(Debug, Deserialize)]
pub struct AdminPostActivityRequest {
    /// Optional team filter applied when selecting the target session(s).
    #[serde(default)]
    pub team: Option<String>,
    /// Target conversation; when absent the activity is fanned out to every matching
    /// session of the tenant.
    #[serde(rename = "conversation_id", default)]
    pub conversation_id: Option<String>,
    /// Activity JSON; must be an object that deserialises as an `Activity`.
    pub activity: serde_json::Value,
}
389
/// Outcome of an admin post: how many conversations received the activity and how many
/// were skipped.
#[derive(Debug, Serialize)]
pub struct AdminPostActivityResponse {
    pub posted: usize,
    pub skipped: usize,
}
395
396async fn admin_post_activity_handler(
397 Extension(state): Extension<Arc<StandaloneState>>,
398 Path(path): Path<AdminPath>,
399 Json(body): Json<AdminPostActivityRequest>,
400) -> Result<Json<AdminPostActivityResponse>, StatusCode> {
401 let AdminPostActivityRequest {
402 team,
403 conversation_id,
404 activity,
405 } = body;
406
407 if !activity.is_object() {
408 return Err(StatusCode::BAD_REQUEST);
409 }
410
411 let base_activity: Activity =
412 serde_json::from_value(activity).map_err(|_| StatusCode::BAD_REQUEST)?;
413 let team_filter = team.as_deref();
414
415 if let Some(conversation_id) = conversation_id {
416 let session = state
417 .sessions
418 .get(&conversation_id)
419 .await
420 .map_err(|err| map_error("load session", err))?
421 .ok_or(StatusCode::NOT_FOUND)?;
422
423 if !session
424 .tenant_ctx
425 .env
426 .as_ref()
427 .eq_ignore_ascii_case(&path.env)
428 || !session
429 .tenant_ctx
430 .tenant
431 .as_ref()
432 .eq_ignore_ascii_case(&path.tenant)
433 {
434 return Err(StatusCode::NOT_FOUND);
435 }
436
437 if let Some(team) = team_filter
438 && !session
439 .tenant_ctx
440 .team
441 .as_ref()
442 .map(|value| value.as_ref().eq_ignore_ascii_case(team))
443 .unwrap_or(false)
444 {
445 return Err(StatusCode::NOT_FOUND);
446 }
447
448 if !session.proactive_ok {
449 return Err(StatusCode::BAD_REQUEST);
450 }
451
452 append_bot_activity(state.as_ref(), &conversation_id, &base_activity).await?;
453
454 return Ok(Json(AdminPostActivityResponse {
455 posted: 1,
456 skipped: 0,
457 }));
458 }
459
460 let sessions = state
461 .sessions
462 .list_by_tenant(&path.env, &path.tenant, team_filter)
463 .await
464 .map_err(|err| map_error("list sessions", err))?;
465
466 if sessions.is_empty() {
467 return Err(StatusCode::NOT_FOUND);
468 }
469
470 let mut posted = 0usize;
471 let mut skipped = 0usize;
472 for session in sessions {
473 if !session.proactive_ok {
474 skipped += 1;
475 continue;
476 }
477
478 match append_bot_activity(state.as_ref(), &session.conversation_id, &base_activity).await {
479 Ok(()) => posted += 1,
480 Err(StatusCode::NOT_FOUND) => {
481 skipped += 1;
482 warn!(
483 conversation = %session.conversation_id,
484 "conversation not found while appending activity"
485 );
486 }
487 Err(code) => return Err(code),
488 }
489 }
490
491 if posted == 0 {
492 return Err(StatusCode::NOT_FOUND);
493 }
494
495 Ok(Json(AdminPostActivityResponse { posted, skipped }))
496}
497
/// Query string of the websocket stream route.
#[derive(Debug, Deserialize)]
struct StreamQuery {
    /// Conversation token (the `t` parameter embedded in the stream URL).
    t: String,
    /// Optional decimal watermark to start the backlog from.
    #[serde(default)]
    watermark: Option<String>,
}
504
505async fn conversation_stream_handler(
506 Extension(state): Extension<Arc<StandaloneState>>,
507 Path(path): Path<ConversationPath>,
508 Query(query): Query<StreamQuery>,
509 ws: WebSocketUpgrade,
510) -> Result<impl IntoResponse, StatusCode> {
511 let claims = jwt::verify(&query.t).map_err(|_| StatusCode::UNAUTHORIZED)?;
512 ensure_conversation_access(state.as_ref(), &path.id, &claims).await?;
513 let watermark = query
514 .watermark
515 .as_deref()
516 .map(parse_watermark)
517 .transpose()?;
518 Ok(ws.on_upgrade(move |socket| async move {
519 if let Err(err) = run_websocket(socket, state, path.id, watermark).await {
520 warn!("websocket closed with error: {err:?}");
521 }
522 }))
523}
524
525async fn run_websocket(
526 mut socket: WebSocket,
527 state: Arc<StandaloneState>,
528 conversation_id: String,
529 watermark: Option<u64>,
530) -> anyhow::Result<()> {
531 let page = state
532 .conversations
533 .activities(&conversation_id, watermark)
534 .await
535 .map_err(|err| anyhow::anyhow!(err.to_string()))?;
536 let initial: Vec<StoredActivity> = page.activities.iter().cloned().collect();
537 if !initial.is_empty() {
538 send_envelope(&mut socket, &initial, page.watermark).await?;
539 }
540 let mut consecutive_failures: u32 = 0;
541 const SEND_FAILURE_THRESHOLD: u32 = 5;
542 let mut current_watermark = page.watermark;
543 let mut subscriber = state
544 .conversations
545 .subscribe(&conversation_id)
546 .await
547 .map_err(|err| anyhow::anyhow!(err.to_string()))?;
548
549 loop {
550 tokio::select! {
551 message = socket.recv() => {
552 match message {
553 Some(Ok(Message::Close(_))) | None => break,
554 Some(Ok(_)) => continue,
555 Some(Err(err)) => {
556 warn!("websocket recv error: {err}");
557 break;
558 }
559 }
560 }
561 received = subscriber.recv() => {
562 match received {
563 Ok(activity) => {
564 if let Err(err) = send_envelope(&mut socket, std::slice::from_ref(&activity), activity.watermark + 1).await {
565 consecutive_failures = consecutive_failures.saturating_add(1);
566 warn!(error = ?err, consecutive_failures, "websocket send error");
567 if consecutive_failures >= SEND_FAILURE_THRESHOLD {
568 warn!("terminating websocket due to repeated send failures");
569 break;
570 }
571 continue;
572 } else {
573 consecutive_failures = 0;
574 }
575 current_watermark = activity.watermark + 1;
576 }
577 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
578 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
579 let page = state.conversations.activities(&conversation_id, Some(current_watermark))
580 .await
581 .map_err(|err| anyhow::anyhow!(err.to_string()))?;
582 let resend: Vec<StoredActivity> = page.activities.iter().cloned().collect();
583 if !resend.is_empty() {
584 if let Err(err) = send_envelope(&mut socket, &resend, page.watermark).await {
585 consecutive_failures = consecutive_failures.saturating_add(1);
586 warn!(error = ?err, consecutive_failures, "websocket resend error");
587 if consecutive_failures >= SEND_FAILURE_THRESHOLD {
588 warn!("terminating websocket due to repeated resend failures");
589 break;
590 }
591 continue;
592 } else {
593 consecutive_failures = 0;
594 }
595 }
596 current_watermark = page.watermark;
597 }
598 }
599 }
600 }
601 }
602 Ok(())
603}
604
605#[doc(hidden)]
606pub async fn test_run_websocket_public(
607 socket: WebSocket,
608 state: Arc<StandaloneState>,
609 conversation_id: String,
610 watermark: Option<u64>,
611) -> anyhow::Result<()> {
612 run_websocket(socket, state, conversation_id, watermark).await
613}
614
615async fn send_envelope(
616 socket: &mut WebSocket,
617 activities: &[StoredActivity],
618 watermark: u64,
619) -> anyhow::Result<()> {
620 if activities.is_empty() {
621 return Ok(());
622 }
623 let payload = envelope_payload(activities, watermark)?;
624 socket.send(Message::Text(payload.into())).await?;
625 Ok(())
626}
627
628fn envelope_payload(activities: &[StoredActivity], watermark: u64) -> anyhow::Result<String> {
629 let acts: Vec<Activity> = activities
630 .iter()
631 .map(|entry| entry.activity.clone())
632 .collect();
633 Ok(serde_json::to_string(&serde_json::json!({
634 "activities": acts,
635 "watermark": watermark.to_string(),
636 }))?)
637}
638
639fn tenant_ctx_from_query(query: &TenantQuery) -> Result<TenantCtx, StatusCode> {
640 let env = query.env.trim();
641 let tenant = query.tenant.trim();
642 if env.is_empty() || tenant.is_empty() {
643 return Err(StatusCode::BAD_REQUEST);
644 }
645 let mut ctx = TenantCtx::new(EnvId(env.to_string()), TenantId(tenant.to_string()));
646 if let Some(team) = &query.team {
647 let team = team.trim();
648 if !team.is_empty() {
649 ctx = ctx.with_team(Some(TeamId(team.to_string())));
650 }
651 }
652 Ok(ctx)
653}
654
655fn tenant_ctx_from_claims(claims: &Claims) -> Result<TenantCtx, StatusCode> {
656 let mut ctx = TenantCtx::new(
657 EnvId(claims.ctx.env.clone()),
658 TenantId(claims.ctx.tenant.clone()),
659 );
660 if let Some(team) = &claims.ctx.team {
661 ctx = ctx.with_team(Some(TeamId(team.clone())));
662 }
663 Ok(ctx)
664}
665
666fn tenant_claims_from_ctx(ctx: &TenantCtx) -> TenantClaims {
667 TenantClaims {
668 env: ctx.env.as_ref().to_string(),
669 tenant: ctx.tenant.as_ref().to_string(),
670 team: ctx.team.as_ref().map(|team| team.as_ref().to_string()),
671 }
672}
673
674fn extract_bearer(headers: &HeaderMap) -> Result<&str, StatusCode> {
675 let value = headers
676 .get(axum::http::header::AUTHORIZATION)
677 .and_then(|value| value.to_str().ok())
678 .ok_or(StatusCode::UNAUTHORIZED)?;
679 if let Some(rest) = value.strip_prefix("Bearer ") {
680 Ok(rest.trim())
681 } else if let Some(rest) = value.strip_prefix("bearer ") {
682 Ok(rest.trim())
683 } else {
684 Err(StatusCode::UNAUTHORIZED)
685 }
686}
687
688fn parse_watermark(value: &str) -> Result<u64, StatusCode> {
689 value.parse::<u64>().map_err(|_| StatusCode::BAD_REQUEST)
690}
691
692fn stream_url_from_headers(
693 state: &StandaloneState,
694 headers: &HeaderMap,
695 conversation_id: &str,
696 token: &str,
697) -> Option<String> {
698 if let Some(host) = headers
699 .get(axum::http::header::HOST)
700 .and_then(|value| value.to_str().ok())
701 {
702 let scheme = if state
703 .provider
704 .config()
705 .direct_line_base()
706 .starts_with("https://")
707 {
708 "wss"
709 } else {
710 "ws"
711 };
712 Some(format!(
713 "{scheme}://{host}/v3/directline/conversations/{conversation_id}/stream?t={token}"
714 ))
715 } else {
716 None
717 }
718}
719
720async fn validate_conversation_token(
721 state: &StandaloneState,
722 conversation_id: &str,
723 headers: &HeaderMap,
724) -> Result<Claims, StatusCode> {
725 let token = extract_bearer(headers)?;
726 let claims = jwt::verify(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
727 ensure_conversation_access(state, conversation_id, &claims).await?;
728 Ok(claims)
729}
730
731async fn ensure_conversation_access(
732 state: &StandaloneState,
733 conversation_id: &str,
734 claims: &Claims,
735) -> Result<(), StatusCode> {
736 if !claims.has_conversation(conversation_id) {
737 return Err(StatusCode::FORBIDDEN);
738 }
739 let claimed_ctx = tenant_ctx_from_claims(claims)?;
740 let stored_ctx = state
741 .conversations
742 .tenant_ctx(conversation_id)
743 .await
744 .map_err(map_store_error)?;
745 if stored_ctx != claimed_ctx {
746 return Err(StatusCode::FORBIDDEN);
747 }
748 Ok(())
749}
750
751fn normalise_activity(activity: &mut Activity, conversation_id: &str, subject: &str) {
752 if activity
753 .from
754 .as_ref()
755 .map(|from| from.id.trim().is_empty())
756 .unwrap_or(true)
757 {
758 activity.from = Some(ChannelAccount {
759 id: subject.to_string(),
760 name: None,
761 role: Some("user".into()),
762 });
763 }
764 if activity.conversation.is_none() {
765 activity.conversation = Some(ConversationAccount {
766 id: conversation_id.to_string(),
767 });
768 }
769}
770
771fn apply_bot_defaults(activity: &mut Activity, conversation_id: &str) {
772 normalise_activity(activity, conversation_id, "bot");
773 if let Some(from) = activity.from.as_mut() {
774 if from.id.trim().is_empty() {
775 from.id = "bot".into();
776 }
777 from.role = Some("bot".into());
778 } else {
779 activity.from = Some(ChannelAccount {
780 id: "bot".into(),
781 name: None,
782 role: Some("bot".into()),
783 });
784 }
785}
786
787fn apply_user_json_defaults(
788 activity: &mut serde_json::Value,
789 conversation_id: &str,
790 subject: &str,
791) {
792 let Some(obj) = activity.as_object_mut() else {
793 return;
794 };
795
796 let from_entry = obj
797 .entry("from".to_string())
798 .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
799 if !from_entry.is_object() {
800 *from_entry = serde_json::Value::Object(serde_json::Map::new());
801 }
802 if let Some(from_obj) = from_entry.as_object_mut() {
803 let id_entry = from_obj
804 .entry("id".to_string())
805 .or_insert_with(|| serde_json::Value::String(subject.to_string()));
806 if id_entry
807 .as_str()
808 .map(|value| value.trim().is_empty())
809 .unwrap_or(true)
810 {
811 *id_entry = serde_json::Value::String(subject.to_string());
812 }
813 let role_entry = from_obj
814 .entry("role".to_string())
815 .or_insert_with(|| serde_json::Value::String("user".into()));
816 if !role_entry
817 .as_str()
818 .map(|role| role.eq_ignore_ascii_case("user"))
819 .unwrap_or(false)
820 {
821 *role_entry = serde_json::Value::String("user".into());
822 }
823 }
824
825 let conversation_entry = obj.entry("conversation".to_string()).or_insert_with(|| {
826 let mut conv = serde_json::Map::new();
827 conv.insert(
828 "id".to_string(),
829 serde_json::Value::String(conversation_id.to_string()),
830 );
831 serde_json::Value::Object(conv)
832 });
833 if let Some(conv_obj) = conversation_entry.as_object_mut() {
834 let id_entry = conv_obj
835 .entry("id".to_string())
836 .or_insert_with(|| serde_json::Value::String(conversation_id.to_string()));
837 if id_entry
838 .as_str()
839 .map(|value| value.trim().is_empty())
840 .unwrap_or(true)
841 {
842 *id_entry = serde_json::Value::String(conversation_id.to_string());
843 }
844 } else {
845 let mut conv = serde_json::Map::new();
846 conv.insert(
847 "id".to_string(),
848 serde_json::Value::String(conversation_id.to_string()),
849 );
850 *conversation_entry = serde_json::Value::Object(conv);
851 }
852}
853
854fn apply_user_defaults(activity: &mut Activity, conversation_id: &str, subject: &str) {
855 normalise_activity(activity, conversation_id, subject);
856 if let Some(from) = activity.from.as_mut() {
857 if from
858 .role
859 .as_deref()
860 .map(|role| role.eq_ignore_ascii_case("user"))
861 .unwrap_or(true)
862 {
863 from.role = Some("user".into());
864 }
865 if from.id.trim().is_empty() {
866 from.id = subject.to_string();
867 }
868 } else {
869 activity.from = Some(ChannelAccount {
870 id: subject.to_string(),
871 name: None,
872 role: Some("user".into()),
873 });
874 }
875}
876
877async fn append_bot_activity(
878 state: &StandaloneState,
879 conversation_id: &str,
880 base_activity: &Activity,
881) -> Result<(), StatusCode> {
882 let mut activity = base_activity.clone();
883 apply_bot_defaults(&mut activity, conversation_id);
884 let stored = state
885 .conversations
886 .append(conversation_id, activity)
887 .await
888 .map_err(map_store_error)?;
889
890 if let Err(err) = state
891 .sessions
892 .update_watermark(conversation_id, Some((stored.watermark + 1).to_string()))
893 .await
894 {
895 warn!(
896 error = %err,
897 conversation = %conversation_id,
898 "failed to update watermark"
899 );
900 }
901
902 Ok(())
903}
904
905fn map_store_error(err: super::conversation::StoreError) -> StatusCode {
906 match err {
907 super::conversation::StoreError::AlreadyExists(_) => StatusCode::CONFLICT,
908 super::conversation::StoreError::NotFound(_) => StatusCode::NOT_FOUND,
909 super::conversation::StoreError::QuotaExceeded(_) => StatusCode::TOO_MANY_REQUESTS,
910 super::conversation::StoreError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
911 }
912}
913
914fn map_error(context: &str, err: anyhow::Error) -> StatusCode {
915 error!("{context}: {err}");
916 StatusCode::INTERNAL_SERVER_ERROR
917}
918
/// Token-bucket rate limiter keyed by client IP address.
///
/// Each IP gets a bucket holding up to `capacity` tokens, refilled continuously at
/// `capacity / window` tokens per second. A request is allowed when a whole token is
/// available.
///
/// A bucket that has refilled to capacity behaves exactly like a fresh one, so such
/// buckets are swept once the map grows past a threshold. Without the sweep, the map
/// would gain one entry per distinct client IP for the lifetime of the process.
struct IpRateLimiter {
    capacity: f64,
    refill_per_sec: f64,
    state: Mutex<LimiterState>,
}

/// Mutable limiter state, guarded by the limiter's mutex.
struct LimiterState {
    buckets: HashMap<IpAddr, TokenBucket>,
    /// Bucket count at which the next sweep of idle buckets runs.
    sweep_at: usize,
}

/// Tokens remaining for one IP, and the instant they were last topped up.
struct TokenBucket {
    tokens: f64,
    last: Instant,
}

impl IpRateLimiter {
    /// Minimum number of tracked IPs before idle buckets are swept.
    const SWEEP_THRESHOLD: usize = 1_024;

    /// Creates a limiter allowing bursts of `capacity` requests per IP, refilling from
    /// empty to full over `window`. A zero `window` refills `capacity` tokens per second.
    fn new(capacity: usize, window: Duration) -> Self {
        let capacity = capacity as f64;
        let refill_per_sec = if window.is_zero() {
            capacity
        } else {
            capacity / window.as_secs_f64()
        };
        Self {
            capacity,
            refill_per_sec,
            state: Mutex::new(LimiterState {
                buckets: HashMap::new(),
                sweep_at: Self::SWEEP_THRESHOLD,
            }),
        }
    }

    /// Consumes one token from `ip`'s bucket; returns `false` when none is available.
    fn check(&self, ip: IpAddr) -> bool {
        let now = Instant::now();
        // Nothing in the critical section can leave the map half-updated, so recovering
        // from a poisoned lock is safe and avoids panicking every later request.
        let mut guard = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let state = &mut *guard;

        if state.buckets.len() >= state.sweep_at && !state.buckets.contains_key(&ip) {
            state.buckets.retain(|_, bucket| !self.is_full(bucket, now));
            // Scale the threshold with the live set so sweeps stay amortised O(1).
            state.sweep_at = (state.buckets.len() * 2).max(Self::SWEEP_THRESHOLD);
        }

        let bucket = state.buckets.entry(ip).or_insert(TokenBucket {
            tokens: self.capacity,
            last: now,
        });
        let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
        if elapsed > 0.0 {
            bucket.tokens = (bucket.tokens + elapsed * self.refill_per_sec).min(self.capacity);
            bucket.last = now;
        }
        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Whether `bucket` would be back at full capacity at `now`, i.e. indistinguishable
    /// from a freshly created bucket.
    fn is_full(&self, bucket: &TokenBucket, now: Instant) -> bool {
        let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
        bucket.tokens + elapsed * self.refill_per_sec >= self.capacity
    }
}
965
// Handler-level tests. Handlers are called directly with hand-built extractors (no HTTP
// server), over in-memory stores, a read-only secrets backend holding a test signing key,
// and a bus that records every published event.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::platforms::webchat::{
        WebChatProvider,
        bus::{EventBus, Subject},
        config::Config,
        session::WebchatSessionStore,
        types::{GreenticEvent, MessagePayload},
    };
    use axum::{
        Extension, Json,
        extract::{Path, Query},
        http::{self, HeaderMap, HeaderValue, StatusCode},
        response::IntoResponse,
    };
    use greentic_secrets_spec::record_from_plain;
    use secrets_core::{Scope, SecretUri, SecretsBackend, VersionedSecret};
    use serde_json::json;
    // Async mutex for `RecordingBus`; this explicit import takes precedence over the
    // `std::sync::Mutex` brought in by `use super::*`.
    use tokio::sync::Mutex;

    /// Read-only secrets backend that only knows the webchat JWT signing key.
    #[derive(Clone)]
    struct TestSecretsBackend {
        secret: String,
    }

    impl TestSecretsBackend {
        fn new(secret: String) -> Self {
            Self { secret }
        }
    }

    impl SecretsBackend for TestSecretsBackend {
        // Writes are rejected: the backend is read-only.
        fn put(
            &self,
            _record: secrets_core::SecretRecord,
        ) -> secrets_core::Result<secrets_core::SecretVersion> {
            Err(secrets_core::Error::Backend(
                "test backend is read-only".into(),
            ))
        }

        // Returns the configured secret as version 1 for `webchat`/`jwt_signing_key`,
        // and `None` for any other URI.
        fn get(
            &self,
            uri: &SecretUri,
            _version: Option<u64>,
        ) -> secrets_core::Result<Option<VersionedSecret>> {
            if uri.category() == "webchat" && uri.name() == "jwt_signing_key" {
                let record = record_from_plain(self.secret.clone());
                Ok(Some(VersionedSecret {
                    version: 1,
                    deleted: false,
                    record: Some(record),
                }))
            } else {
                Ok(None)
            }
        }

        // Always lists exactly the signing key, regardless of scope or prefixes.
        fn list(
            &self,
            _scope: &secrets_core::Scope,
            _category_prefix: Option<&str>,
            _name_prefix: Option<&str>,
        ) -> secrets_core::Result<Vec<secrets_core::SecretListItem>> {
            let uri = SecretUri::new(
                Scope::new("global", "webchat", None).expect("scope"),
                "webchat".to_string(),
                "jwt_signing_key".to_string(),
            )
            .expect("uri");
            Ok(vec![secrets_core::SecretListItem {
                uri,
                visibility: secrets_core::Visibility::Tenant,
                latest_version: Some("1".into()),
                content_type: secrets_core::ContentType::Opaque,
            }])
        }

        // Deletes are rejected: the backend is read-only.
        fn delete(&self, _uri: &SecretUri) -> secrets_core::Result<secrets_core::SecretVersion> {
            Err(secrets_core::Error::Backend(
                "test backend is read-only".into(),
            ))
        }

        // Every URI reports a single live version 1.
        fn versions(
            &self,
            _uri: &SecretUri,
        ) -> secrets_core::Result<Vec<secrets_core::SecretVersion>> {
            Ok(vec![secrets_core::SecretVersion {
                version: 1,
                deleted: false,
            }])
        }

        fn exists(&self, uri: &SecretUri) -> secrets_core::Result<bool> {
            Ok(uri.category() == "webchat" && uri.name() == "jwt_signing_key")
        }
    }

    /// Provider configured with `base_url`, whose signing key resolves from
    /// `TestSecretsBackend` in the `global/webchat` scope.
    fn test_provider(base_url: &str) -> WebChatProvider {
        let backend = Arc::new(TestSecretsBackend::new("test-signing-key".to_string()));
        let scope = Scope::new("global", "webchat", None).expect("valid signing scope");
        WebChatProvider::new(Config::with_base_url(base_url), backend).with_signing_scope(scope)
    }

    // A posted user message is appended to the store (observed via a subscriber) and
    // published exactly once to the bus as a text payload.
    #[tokio::test]
    async fn user_activity_is_normalized_and_streamed() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(
                provider,
                conversations.clone(),
                sessions.clone(),
                bus.clone(),
            )
            .await
            .expect("state"),
        );

        let user_token = issue_token(&state, "user-1").await;
        let (conversation_id, conversation_token) = start_conversation(&state, &user_token).await;

        let mut subscriber = conversations
            .subscribe(&conversation_id)
            .await
            .expect("subscribe");

        post_user_activity(
            &state,
            &conversation_id,
            &conversation_token,
            json!({
                "type": "message",
                "text": "hello"
            }),
        )
        .await;

        let stored = subscriber.recv().await.expect("activity");
        assert_eq!(stored.activity.text.as_deref(), Some("hello"));

        let events = bus.take().await;
        assert_eq!(events.len(), 1);
        match &events[0] {
            GreenticEvent::IncomingMessage(msg) => match &msg.payload {
                MessagePayload::Text { text, .. } => assert_eq!(text, "hello"),
                other => panic!("unexpected payload: {other:?}"),
            },
        }
    }

    // An `invoke` activity reaches the bus as an event payload named after the invoke.
    #[tokio::test]
    async fn invoke_is_normalized_to_event() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(provider, conversations.clone(), sessions, bus.clone())
                .await
                .expect("state"),
        );
        let user_token = issue_token(&state, "user-2").await;
        let (conversation_id, conversation_token) = start_conversation(&state, &user_token).await;

        post_user_activity(
            &state,
            &conversation_id,
            &conversation_token,
            json!({
                "type": "invoke",
                "name": "adaptiveCard/action",
                "value": {"foo": "bar"}
            }),
        )
        .await;

        let events = bus.take().await;
        assert_eq!(events.len(), 1);
        match &events[0] {
            GreenticEvent::IncomingMessage(msg) => match &msg.payload {
                MessagePayload::Event { name, .. } => assert_eq!(name, "adaptiveCard/action"),
                other => panic!("expected event payload, got {other:?}"),
            },
        }
    }

    // The admin endpoint appends a bot-role activity to a single conversation, advances
    // the session watermark to "1" (the first activity's watermark plus one), and
    // publishes nothing to the bus.
    #[tokio::test]
    async fn admin_bot_activity_appends_and_streams() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(
                provider,
                conversations.clone(),
                sessions.clone(),
                bus.clone(),
            )
            .await
            .expect("state"),
        );

        let user_token = issue_token(&state, "user-3").await;
        let (conversation_id, _) = start_conversation(&state, &user_token).await;
        let mut subscriber = conversations
            .subscribe(&conversation_id)
            .await
            .expect("subscribe");

        let Json(response) = admin_post_activity_handler(
            Extension(Arc::clone(&state)),
            Path(AdminPath {
                env: "dev".to_string(),
                tenant: "acme".to_string(),
            }),
            Json(AdminPostActivityRequest {
                team: None,
                conversation_id: Some(conversation_id.clone()),
                activity: json!({
                    "type": "message",
                    "text": "bot says hi"
                }),
            }),
        )
        .await
        .expect("admin post");

        assert_eq!(response.posted, 1);
        assert_eq!(response.skipped, 0);

        let stored = subscriber.recv().await.expect("activity");
        assert_eq!(stored.activity.text.as_deref(), Some("bot says hi"));
        assert!(
            stored
                .activity
                .from
                .as_ref()
                .and_then(|from| from.role.as_deref())
                .map(|role| role.eq_ignore_ascii_case("bot"))
                .unwrap_or(false)
        );

        let session = sessions
            .get(&conversation_id)
            .await
            .expect("session fetch")
            .expect("session exists");
        assert_eq!(session.watermark.as_deref(), Some("1"));

        assert!(bus.take().await.is_empty());
    }

    /// Calls `generate_token_handler` for tenant `dev/acme` as `user_id`, with no peer
    /// address (so the loopback rate-limit bucket is used), and returns the token.
    async fn issue_token(state: &Arc<StandaloneState>, user_id: &str) -> String {
        let query = TenantQuery {
            env: "dev".to_string(),
            tenant: "acme".to_string(),
            team: None,
        };
        let body = GenerateTokenRequest {
            user: Some(UserDescriptor {
                id: Some(user_id.to_string()),
            }),
            trusted_origins: None,
        };
        let Json(response) = generate_token_handler(
            Extension(Arc::clone(state)),
            RemoteIp(None),
            Query(query),
            Json(body),
        )
        .await
        .expect("generate token");
        response.token
    }

    /// Starts a conversation with a user token; returns `(conversation_id, conversation_token)`.
    async fn start_conversation(state: &Arc<StandaloneState>, token: &str) -> (String, String) {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );
        let Json(response) = start_conversation_handler(Extension(Arc::clone(state)), headers)
            .await
            .expect("start conversation");
        (response.conversation_id, response.token)
    }

    /// Posts `body` as a user activity with the conversation token and asserts `201 Created`.
    async fn post_user_activity(
        state: &Arc<StandaloneState>,
        conversation_id: &str,
        token: &str,
        body: serde_json::Value,
    ) {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );
        headers.insert(
            http::header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
        let response = post_activity_handler(
            Extension(Arc::clone(state)),
            Path(ConversationPath {
                id: conversation_id.to_string(),
            }),
            headers,
            Json(body),
        )
        .await
        .expect("post activity");
        let response = response.into_response();
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    /// Event bus that records every published event (subjects are ignored).
    #[derive(Default)]
    struct RecordingBus {
        events: Mutex<Vec<GreenticEvent>>,
    }

    impl RecordingBus {
        /// Drains and returns all events recorded so far.
        async fn take(&self) -> Vec<GreenticEvent> {
            let mut guard = self.events.lock().await;
            std::mem::take(&mut *guard)
        }
    }

    #[async_trait::async_trait]
    impl EventBus for RecordingBus {
        async fn publish(&self, _subject: &Subject, event: &GreenticEvent) -> anyhow::Result<()> {
            self.events.lock().await.push(event.clone());
            Ok(())
        }
    }
}