gsm_core/platforms/webchat/
standalone.rs

1//! Standalone Direct Line server implementation.
2//!
3//! This module exposes a Direct Line–compatible surface area that aligns with
4//! Microsoft Bot Framework expectations while keeping all state within the
5//! Greentic stack. Tokens are minted locally and conversations stream through
6//! the in-memory [`ConversationStore`].
7
8use std::{
9    collections::HashMap,
10    net::{IpAddr, SocketAddr},
11    sync::{Arc, Mutex},
12    time::{Duration, Instant},
13};
14
15use anyhow::Context as _;
16use axum::{
17    Json, Router,
18    extract::{
19        ConnectInfo, Extension, FromRequestParts, Path, Query, WebSocketUpgrade,
20        ws::{Message, WebSocket},
21    },
22    http::{HeaderMap, StatusCode, request::Parts},
23    response::IntoResponse,
24    routing::{get, post},
25};
26use serde::{Deserialize, Serialize};
27use tracing::{debug, error, warn};
28use uuid::Uuid;
29
30use super::{
31    WebChatProvider,
32    auth::{self as jwt, Claims, TenantClaims},
33    bus::{NoopBus, SharedBus},
34    conversation::{
35        Activity, ChannelAccount, ConversationAccount, SharedConversationStore, StoredActivity,
36        memory_store,
37    },
38    ingress,
39    session::{MemorySessionStore, SharedSessionStore, WebchatSession},
40};
41use greentic_types::{EnvId, TeamId, TenantCtx, TenantId};
42
/// Lifetime, in seconds, of every token minted by this module (30 minutes).
const TOKEN_TTL_SECONDS: u64 = 30 * 60;
/// Number of token-generation requests a single client IP may burst before
/// being throttled.
const RATE_LIMIT_CAPACITY: usize = 5;
/// Period over which a client IP's full [`RATE_LIMIT_CAPACITY`] budget refills.
const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(60);
46
47struct RemoteIp(Option<IpAddr>);
48
49impl<S> FromRequestParts<S> for RemoteIp
50where
51    S: Send + Sync,
52{
53    type Rejection = std::convert::Infallible;
54
55    fn from_request_parts(
56        parts: &mut Parts,
57        _state: &S,
58    ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
59        let ip = parts
60            .extensions
61            .get::<ConnectInfo<SocketAddr>>()
62            .map(|ConnectInfo(addr)| addr.ip());
63        std::future::ready(Ok(RemoteIp(ip)))
64    }
65}
66
67/// Shared state for the standalone Direct Line server.
68#[derive(Clone)]
69pub struct StandaloneState {
70    pub provider: WebChatProvider,
71    pub conversations: SharedConversationStore,
72    pub sessions: SharedSessionStore,
73    pub bus: SharedBus,
74    rate_limiter: Arc<IpRateLimiter>,
75}
76
77impl StandaloneState {
78    /// Builds a new state wired to the provided conversation store.
79    pub async fn with_store(
80        provider: WebChatProvider,
81        conversations: SharedConversationStore,
82        sessions: SharedSessionStore,
83        bus: SharedBus,
84    ) -> anyhow::Result<Self> {
85        let signing_keys = provider
86            .signing_keys()
87            .await
88            .context("failed to resolve Direct Line signing key")?;
89        if let Err(err) = jwt::install_keys(signing_keys) {
90            debug!("jwt keys already installed: {err}");
91        }
92        Ok(Self {
93            provider,
94            conversations,
95            sessions,
96            bus,
97            rate_limiter: Arc::new(IpRateLimiter::new(RATE_LIMIT_CAPACITY, RATE_LIMIT_WINDOW)),
98        })
99    }
100
101    /// Builds a new state with the default in-memory conversation store.
102    pub async fn new(provider: WebChatProvider) -> anyhow::Result<Self> {
103        Self::with_store(
104            provider,
105            memory_store(),
106            Arc::new(MemorySessionStore::default()),
107            Arc::new(NoopBus),
108        )
109        .await
110    }
111}
112
113fn router_blueprint() -> Router {
114    Router::new()
115        .route(
116            "/v3/directline/tokens/generate",
117            post(generate_token_handler),
118        )
119        .route(
120            "/v3/directline/conversations",
121            post(start_conversation_handler),
122        )
123        .route(
124            "/v3/directline/conversations/{id}/activities",
125            get(list_activities_handler).post(post_activity_handler),
126        )
127        .route(
128            "/v3/directline/conversations/{id}/stream",
129            get(conversation_stream_handler),
130        )
131        .route(
132            "/webchat/admin/{env}/{tenant}/post-activity",
133            post(admin_post_activity_handler),
134        )
135}
136
137/// Builds the Axum router that serves `/v3/directline` endpoints.
138pub fn router(state: Arc<StandaloneState>) -> Router {
139    router_blueprint().layer(axum::Extension(state))
140}
141
142#[allow(dead_code)]
143const fn assert_state_bounds<T: Send + Sync + Clone>() {}
144const _: () = {
145    assert_state_bounds::<StandaloneState>();
146};
147
148#[doc(hidden)]
149#[derive(Debug, Deserialize)]
150pub struct TenantQuery {
151    env: String,
152    tenant: String,
153    #[serde(default)]
154    team: Option<String>,
155}
156
157#[doc(hidden)]
158#[derive(Debug, Deserialize)]
159#[serde(rename_all = "camelCase")]
160pub struct GenerateTokenRequest {
161    #[serde(default)]
162    user: Option<UserDescriptor>,
163    #[allow(dead_code)]
164    #[serde(default)]
165    trusted_origins: Option<Vec<String>>,
166}
167
168#[doc(hidden)]
169#[derive(Debug, Deserialize)]
170pub struct UserDescriptor {
171    #[serde(default)]
172    id: Option<String>,
173}
174
175#[derive(Debug, Serialize, Deserialize)]
176struct TokenResponse {
177    token: String,
178    expires_in: u64,
179}
180
181async fn generate_token_handler(
182    Extension(state): Extension<Arc<StandaloneState>>,
183    remote: RemoteIp,
184    Query(query): Query<TenantQuery>,
185    Json(body): Json<GenerateTokenRequest>,
186) -> Result<Json<TokenResponse>, StatusCode> {
187    let ip = remote.0.unwrap_or_else(|| IpAddr::from([127, 0, 0, 1]));
188    if !state.rate_limiter.check(ip) {
189        return Err(StatusCode::TOO_MANY_REQUESTS);
190    }
191
192    let tenant_ctx = tenant_ctx_from_query(&query)?;
193    let subject = body
194        .user
195        .and_then(|user| user.id)
196        .filter(|id| !id.trim().is_empty())
197        .unwrap_or_else(|| format!("user:{}", Uuid::new_v4()));
198
199    let claims = Claims::new(
200        subject,
201        tenant_claims_from_ctx(&tenant_ctx),
202        jwt::ttl(TOKEN_TTL_SECONDS),
203    );
204    let token = jwt::sign(&claims).map_err(|err| map_error("sign token", err))?;
205    Ok(Json(TokenResponse {
206        token,
207        expires_in: TOKEN_TTL_SECONDS,
208    }))
209}
210
211#[doc(hidden)]
212#[derive(Debug, Deserialize)]
213pub struct ConversationPath {
214    id: String,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218#[serde(rename_all = "camelCase")]
219struct ConversationResponse {
220    conversation_id: String,
221    token: String,
222    #[serde(skip_serializing_if = "Option::is_none")]
223    stream_url: Option<String>,
224}
225
226async fn start_conversation_handler(
227    Extension(state): Extension<Arc<StandaloneState>>,
228    headers: HeaderMap,
229) -> Result<Json<ConversationResponse>, StatusCode> {
230    let token = extract_bearer(&headers)?;
231    let claims = jwt::verify(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
232    if claims.conv.is_some() {
233        return Err(StatusCode::BAD_REQUEST);
234    }
235    let tenant_ctx = tenant_ctx_from_claims(&claims)?;
236    let conversation_id = Uuid::new_v4().to_string();
237    state
238        .conversations
239        .create(&conversation_id, tenant_ctx.clone())
240        .await
241        .map_err(map_store_error)?;
242
243    let conversation_token = Claims::new(
244        claims.sub.clone(),
245        claims.ctx.clone(),
246        jwt::ttl(TOKEN_TTL_SECONDS),
247    )
248    .with_conversation(conversation_id.clone());
249    let token = jwt::sign(&conversation_token).map_err(|err| map_error("sign token", err))?;
250
251    if let Err(err) = state
252        .sessions
253        .upsert(WebchatSession::new(
254            conversation_id.clone(),
255            tenant_ctx.clone(),
256            token.clone(),
257        ))
258        .await
259    {
260        error!(error = %err, "failed to persist webchat session");
261        return Err(StatusCode::INTERNAL_SERVER_ERROR);
262    }
263
264    let stream_url = stream_url_from_headers(&state, &headers, &conversation_id, &token);
265    Ok(Json(ConversationResponse {
266        conversation_id,
267        token,
268        stream_url,
269    }))
270}
271
272#[doc(hidden)]
273#[derive(Debug, Deserialize)]
274pub struct ActivitiesQuery {
275    #[serde(default)]
276    watermark: Option<String>,
277}
278
279#[derive(Debug, Serialize)]
280struct ActivitiesResponse {
281    activities: Vec<Activity>,
282    watermark: String,
283}
284
285async fn list_activities_handler(
286    Extension(state): Extension<Arc<StandaloneState>>,
287    Path(path): Path<ConversationPath>,
288    headers: HeaderMap,
289    Query(query): Query<ActivitiesQuery>,
290) -> Result<Json<ActivitiesResponse>, StatusCode> {
291    let _claims = validate_conversation_token(state.as_ref(), &path.id, &headers).await?;
292    let watermark = query
293        .watermark
294        .as_deref()
295        .and_then(|value| {
296            let trimmed = value.trim();
297            if trimmed.is_empty() {
298                None
299            } else {
300                Some(trimmed)
301            }
302        })
303        .map(parse_watermark)
304        .transpose()?;
305    let page = state
306        .conversations
307        .activities(&path.id, watermark)
308        .await
309        .map_err(map_store_error)?;
310    let activities: Vec<Activity> = page
311        .activities
312        .into_iter()
313        .map(|entry| entry.activity)
314        .collect();
315    Ok(Json(ActivitiesResponse {
316        activities,
317        watermark: page.watermark.to_string(),
318    }))
319}
320
321#[derive(Debug, Serialize)]
322struct ActivityAck {
323    id: String,
324}
325
326async fn post_activity_handler(
327    Extension(state): Extension<Arc<StandaloneState>>,
328    Path(path): Path<ConversationPath>,
329    headers: HeaderMap,
330    Json(payload): Json<serde_json::Value>,
331) -> Result<impl IntoResponse, StatusCode> {
332    let claims = validate_conversation_token(state.as_ref(), &path.id, &headers).await?;
333    let tenant_ctx = state
334        .conversations
335        .tenant_ctx(&path.id)
336        .await
337        .map_err(map_store_error)?;
338
339    let mut bus_activity = payload.clone();
340    apply_user_json_defaults(&mut bus_activity, &path.id, &claims.sub);
341
342    let mut activity: Activity =
343        serde_json::from_value(payload).map_err(|_| StatusCode::BAD_REQUEST)?;
344    apply_user_defaults(&mut activity, &path.id, &claims.sub);
345
346    let ingress = ingress::Ingress::new(state.bus.clone(), state.sessions.clone());
347    if let Err(err) = ingress
348        .publish_incoming(&bus_activity, &tenant_ctx, &path.id)
349        .await
350    {
351        warn!(error = %err, "failed to publish incoming activity");
352    }
353
354    let stored = state
355        .conversations
356        .append(&path.id, activity)
357        .await
358        .map_err(map_store_error)?;
359
360    if let Err(err) = state
361        .sessions
362        .update_watermark(&path.id, Some((stored.watermark + 1).to_string()))
363        .await
364    {
365        warn!(error = %err, "failed to update watermark");
366    }
367    Ok((
368        StatusCode::CREATED,
369        Json(ActivityAck {
370            id: stored.activity.id,
371        }),
372    ))
373}
374
375#[derive(Debug, Deserialize)]
376pub struct AdminPath {
377    pub env: String,
378    pub tenant: String,
379}
380
381#[derive(Debug, Deserialize)]
382pub struct AdminPostActivityRequest {
383    #[serde(default)]
384    pub team: Option<String>,
385    #[serde(rename = "conversation_id", default)]
386    pub conversation_id: Option<String>,
387    pub activity: serde_json::Value,
388}
389
390#[derive(Debug, Serialize)]
391pub struct AdminPostActivityResponse {
392    pub posted: usize,
393    pub skipped: usize,
394}
395
396async fn admin_post_activity_handler(
397    Extension(state): Extension<Arc<StandaloneState>>,
398    Path(path): Path<AdminPath>,
399    Json(body): Json<AdminPostActivityRequest>,
400) -> Result<Json<AdminPostActivityResponse>, StatusCode> {
401    let AdminPostActivityRequest {
402        team,
403        conversation_id,
404        activity,
405    } = body;
406
407    if !activity.is_object() {
408        return Err(StatusCode::BAD_REQUEST);
409    }
410
411    let base_activity: Activity =
412        serde_json::from_value(activity).map_err(|_| StatusCode::BAD_REQUEST)?;
413    let team_filter = team.as_deref();
414
415    if let Some(conversation_id) = conversation_id {
416        let session = state
417            .sessions
418            .get(&conversation_id)
419            .await
420            .map_err(|err| map_error("load session", err))?
421            .ok_or(StatusCode::NOT_FOUND)?;
422
423        if !session
424            .tenant_ctx
425            .env
426            .as_ref()
427            .eq_ignore_ascii_case(&path.env)
428            || !session
429                .tenant_ctx
430                .tenant
431                .as_ref()
432                .eq_ignore_ascii_case(&path.tenant)
433        {
434            return Err(StatusCode::NOT_FOUND);
435        }
436
437        if let Some(team) = team_filter
438            && !session
439                .tenant_ctx
440                .team
441                .as_ref()
442                .map(|value| value.as_ref().eq_ignore_ascii_case(team))
443                .unwrap_or(false)
444        {
445            return Err(StatusCode::NOT_FOUND);
446        }
447
448        if !session.proactive_ok {
449            return Err(StatusCode::BAD_REQUEST);
450        }
451
452        append_bot_activity(state.as_ref(), &conversation_id, &base_activity).await?;
453
454        return Ok(Json(AdminPostActivityResponse {
455            posted: 1,
456            skipped: 0,
457        }));
458    }
459
460    let sessions = state
461        .sessions
462        .list_by_tenant(&path.env, &path.tenant, team_filter)
463        .await
464        .map_err(|err| map_error("list sessions", err))?;
465
466    if sessions.is_empty() {
467        return Err(StatusCode::NOT_FOUND);
468    }
469
470    let mut posted = 0usize;
471    let mut skipped = 0usize;
472    for session in sessions {
473        if !session.proactive_ok {
474            skipped += 1;
475            continue;
476        }
477
478        match append_bot_activity(state.as_ref(), &session.conversation_id, &base_activity).await {
479            Ok(()) => posted += 1,
480            Err(StatusCode::NOT_FOUND) => {
481                skipped += 1;
482                warn!(
483                    conversation = %session.conversation_id,
484                    "conversation not found while appending activity"
485                );
486            }
487            Err(code) => return Err(code),
488        }
489    }
490
491    if posted == 0 {
492        return Err(StatusCode::NOT_FOUND);
493    }
494
495    Ok(Json(AdminPostActivityResponse { posted, skipped }))
496}
497
498#[derive(Debug, Deserialize)]
499struct StreamQuery {
500    t: String,
501    #[serde(default)]
502    watermark: Option<String>,
503}
504
505async fn conversation_stream_handler(
506    Extension(state): Extension<Arc<StandaloneState>>,
507    Path(path): Path<ConversationPath>,
508    Query(query): Query<StreamQuery>,
509    ws: WebSocketUpgrade,
510) -> Result<impl IntoResponse, StatusCode> {
511    let claims = jwt::verify(&query.t).map_err(|_| StatusCode::UNAUTHORIZED)?;
512    ensure_conversation_access(state.as_ref(), &path.id, &claims).await?;
513    let watermark = query
514        .watermark
515        .as_deref()
516        .map(parse_watermark)
517        .transpose()?;
518    Ok(ws.on_upgrade(move |socket| async move {
519        if let Err(err) = run_websocket(socket, state, path.id, watermark).await {
520            warn!("websocket closed with error: {err:?}");
521        }
522    }))
523}
524
525async fn run_websocket(
526    mut socket: WebSocket,
527    state: Arc<StandaloneState>,
528    conversation_id: String,
529    watermark: Option<u64>,
530) -> anyhow::Result<()> {
531    let page = state
532        .conversations
533        .activities(&conversation_id, watermark)
534        .await
535        .map_err(|err| anyhow::anyhow!(err.to_string()))?;
536    let initial: Vec<StoredActivity> = page.activities.iter().cloned().collect();
537    if !initial.is_empty() {
538        send_envelope(&mut socket, &initial, page.watermark).await?;
539    }
540    let mut consecutive_failures: u32 = 0;
541    const SEND_FAILURE_THRESHOLD: u32 = 5;
542    let mut current_watermark = page.watermark;
543    let mut subscriber = state
544        .conversations
545        .subscribe(&conversation_id)
546        .await
547        .map_err(|err| anyhow::anyhow!(err.to_string()))?;
548
549    loop {
550        tokio::select! {
551            message = socket.recv() => {
552                match message {
553                    Some(Ok(Message::Close(_))) | None => break,
554                    Some(Ok(_)) => continue,
555                    Some(Err(err)) => {
556                        warn!("websocket recv error: {err}");
557                        break;
558                    }
559                }
560            }
561            received = subscriber.recv() => {
562                match received {
563                    Ok(activity) => {
564                        if let Err(err) = send_envelope(&mut socket, std::slice::from_ref(&activity), activity.watermark + 1).await {
565                            consecutive_failures = consecutive_failures.saturating_add(1);
566                            warn!(error = ?err, consecutive_failures, "websocket send error");
567                            if consecutive_failures >= SEND_FAILURE_THRESHOLD {
568                                warn!("terminating websocket due to repeated send failures");
569                                break;
570                            }
571                            continue;
572                        } else {
573                            consecutive_failures = 0;
574                        }
575                        current_watermark = activity.watermark + 1;
576                    }
577                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
578                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
579                        let page = state.conversations.activities(&conversation_id, Some(current_watermark))
580                            .await
581                            .map_err(|err| anyhow::anyhow!(err.to_string()))?;
582                        let resend: Vec<StoredActivity> = page.activities.iter().cloned().collect();
583                        if !resend.is_empty() {
584                            if let Err(err) = send_envelope(&mut socket, &resend, page.watermark).await {
585                                consecutive_failures = consecutive_failures.saturating_add(1);
586                                warn!(error = ?err, consecutive_failures, "websocket resend error");
587                                if consecutive_failures >= SEND_FAILURE_THRESHOLD {
588                                    warn!("terminating websocket due to repeated resend failures");
589                                    break;
590                                }
591                                continue;
592                            } else {
593                                consecutive_failures = 0;
594                            }
595                        }
596                        current_watermark = page.watermark;
597                    }
598                }
599            }
600        }
601    }
602    Ok(())
603}
604
605#[doc(hidden)]
606pub async fn test_run_websocket_public(
607    socket: WebSocket,
608    state: Arc<StandaloneState>,
609    conversation_id: String,
610    watermark: Option<u64>,
611) -> anyhow::Result<()> {
612    run_websocket(socket, state, conversation_id, watermark).await
613}
614
615async fn send_envelope(
616    socket: &mut WebSocket,
617    activities: &[StoredActivity],
618    watermark: u64,
619) -> anyhow::Result<()> {
620    if activities.is_empty() {
621        return Ok(());
622    }
623    let payload = envelope_payload(activities, watermark)?;
624    socket.send(Message::Text(payload.into())).await?;
625    Ok(())
626}
627
628fn envelope_payload(activities: &[StoredActivity], watermark: u64) -> anyhow::Result<String> {
629    let acts: Vec<Activity> = activities
630        .iter()
631        .map(|entry| entry.activity.clone())
632        .collect();
633    Ok(serde_json::to_string(&serde_json::json!({
634        "activities": acts,
635        "watermark": watermark.to_string(),
636    }))?)
637}
638
639fn tenant_ctx_from_query(query: &TenantQuery) -> Result<TenantCtx, StatusCode> {
640    let env = query.env.trim();
641    let tenant = query.tenant.trim();
642    if env.is_empty() || tenant.is_empty() {
643        return Err(StatusCode::BAD_REQUEST);
644    }
645    let mut ctx = TenantCtx::new(EnvId(env.to_string()), TenantId(tenant.to_string()));
646    if let Some(team) = &query.team {
647        let team = team.trim();
648        if !team.is_empty() {
649            ctx = ctx.with_team(Some(TeamId(team.to_string())));
650        }
651    }
652    Ok(ctx)
653}
654
655fn tenant_ctx_from_claims(claims: &Claims) -> Result<TenantCtx, StatusCode> {
656    let mut ctx = TenantCtx::new(
657        EnvId(claims.ctx.env.clone()),
658        TenantId(claims.ctx.tenant.clone()),
659    );
660    if let Some(team) = &claims.ctx.team {
661        ctx = ctx.with_team(Some(TeamId(team.clone())));
662    }
663    Ok(ctx)
664}
665
666fn tenant_claims_from_ctx(ctx: &TenantCtx) -> TenantClaims {
667    TenantClaims {
668        env: ctx.env.as_ref().to_string(),
669        tenant: ctx.tenant.as_ref().to_string(),
670        team: ctx.team.as_ref().map(|team| team.as_ref().to_string()),
671    }
672}
673
674fn extract_bearer(headers: &HeaderMap) -> Result<&str, StatusCode> {
675    let value = headers
676        .get(axum::http::header::AUTHORIZATION)
677        .and_then(|value| value.to_str().ok())
678        .ok_or(StatusCode::UNAUTHORIZED)?;
679    if let Some(rest) = value.strip_prefix("Bearer ") {
680        Ok(rest.trim())
681    } else if let Some(rest) = value.strip_prefix("bearer ") {
682        Ok(rest.trim())
683    } else {
684        Err(StatusCode::UNAUTHORIZED)
685    }
686}
687
688fn parse_watermark(value: &str) -> Result<u64, StatusCode> {
689    value.parse::<u64>().map_err(|_| StatusCode::BAD_REQUEST)
690}
691
692fn stream_url_from_headers(
693    state: &StandaloneState,
694    headers: &HeaderMap,
695    conversation_id: &str,
696    token: &str,
697) -> Option<String> {
698    if let Some(host) = headers
699        .get(axum::http::header::HOST)
700        .and_then(|value| value.to_str().ok())
701    {
702        let scheme = if state
703            .provider
704            .config()
705            .direct_line_base()
706            .starts_with("https://")
707        {
708            "wss"
709        } else {
710            "ws"
711        };
712        Some(format!(
713            "{scheme}://{host}/v3/directline/conversations/{conversation_id}/stream?t={token}"
714        ))
715    } else {
716        None
717    }
718}
719
720async fn validate_conversation_token(
721    state: &StandaloneState,
722    conversation_id: &str,
723    headers: &HeaderMap,
724) -> Result<Claims, StatusCode> {
725    let token = extract_bearer(headers)?;
726    let claims = jwt::verify(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
727    ensure_conversation_access(state, conversation_id, &claims).await?;
728    Ok(claims)
729}
730
731async fn ensure_conversation_access(
732    state: &StandaloneState,
733    conversation_id: &str,
734    claims: &Claims,
735) -> Result<(), StatusCode> {
736    if !claims.has_conversation(conversation_id) {
737        return Err(StatusCode::FORBIDDEN);
738    }
739    let claimed_ctx = tenant_ctx_from_claims(claims)?;
740    let stored_ctx = state
741        .conversations
742        .tenant_ctx(conversation_id)
743        .await
744        .map_err(map_store_error)?;
745    if stored_ctx != claimed_ctx {
746        return Err(StatusCode::FORBIDDEN);
747    }
748    Ok(())
749}
750
751fn normalise_activity(activity: &mut Activity, conversation_id: &str, subject: &str) {
752    if activity
753        .from
754        .as_ref()
755        .map(|from| from.id.trim().is_empty())
756        .unwrap_or(true)
757    {
758        activity.from = Some(ChannelAccount {
759            id: subject.to_string(),
760            name: None,
761            role: Some("user".into()),
762        });
763    }
764    if activity.conversation.is_none() {
765        activity.conversation = Some(ConversationAccount {
766            id: conversation_id.to_string(),
767        });
768    }
769}
770
771fn apply_bot_defaults(activity: &mut Activity, conversation_id: &str) {
772    normalise_activity(activity, conversation_id, "bot");
773    if let Some(from) = activity.from.as_mut() {
774        if from.id.trim().is_empty() {
775            from.id = "bot".into();
776        }
777        from.role = Some("bot".into());
778    } else {
779        activity.from = Some(ChannelAccount {
780            id: "bot".into(),
781            name: None,
782            role: Some("bot".into()),
783        });
784    }
785}
786
787fn apply_user_json_defaults(
788    activity: &mut serde_json::Value,
789    conversation_id: &str,
790    subject: &str,
791) {
792    let Some(obj) = activity.as_object_mut() else {
793        return;
794    };
795
796    let from_entry = obj
797        .entry("from".to_string())
798        .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
799    if !from_entry.is_object() {
800        *from_entry = serde_json::Value::Object(serde_json::Map::new());
801    }
802    if let Some(from_obj) = from_entry.as_object_mut() {
803        let id_entry = from_obj
804            .entry("id".to_string())
805            .or_insert_with(|| serde_json::Value::String(subject.to_string()));
806        if id_entry
807            .as_str()
808            .map(|value| value.trim().is_empty())
809            .unwrap_or(true)
810        {
811            *id_entry = serde_json::Value::String(subject.to_string());
812        }
813        let role_entry = from_obj
814            .entry("role".to_string())
815            .or_insert_with(|| serde_json::Value::String("user".into()));
816        if !role_entry
817            .as_str()
818            .map(|role| role.eq_ignore_ascii_case("user"))
819            .unwrap_or(false)
820        {
821            *role_entry = serde_json::Value::String("user".into());
822        }
823    }
824
825    let conversation_entry = obj.entry("conversation".to_string()).or_insert_with(|| {
826        let mut conv = serde_json::Map::new();
827        conv.insert(
828            "id".to_string(),
829            serde_json::Value::String(conversation_id.to_string()),
830        );
831        serde_json::Value::Object(conv)
832    });
833    if let Some(conv_obj) = conversation_entry.as_object_mut() {
834        let id_entry = conv_obj
835            .entry("id".to_string())
836            .or_insert_with(|| serde_json::Value::String(conversation_id.to_string()));
837        if id_entry
838            .as_str()
839            .map(|value| value.trim().is_empty())
840            .unwrap_or(true)
841        {
842            *id_entry = serde_json::Value::String(conversation_id.to_string());
843        }
844    } else {
845        let mut conv = serde_json::Map::new();
846        conv.insert(
847            "id".to_string(),
848            serde_json::Value::String(conversation_id.to_string()),
849        );
850        *conversation_entry = serde_json::Value::Object(conv);
851    }
852}
853
854fn apply_user_defaults(activity: &mut Activity, conversation_id: &str, subject: &str) {
855    normalise_activity(activity, conversation_id, subject);
856    if let Some(from) = activity.from.as_mut() {
857        if from
858            .role
859            .as_deref()
860            .map(|role| role.eq_ignore_ascii_case("user"))
861            .unwrap_or(true)
862        {
863            from.role = Some("user".into());
864        }
865        if from.id.trim().is_empty() {
866            from.id = subject.to_string();
867        }
868    } else {
869        activity.from = Some(ChannelAccount {
870            id: subject.to_string(),
871            name: None,
872            role: Some("user".into()),
873        });
874    }
875}
876
877async fn append_bot_activity(
878    state: &StandaloneState,
879    conversation_id: &str,
880    base_activity: &Activity,
881) -> Result<(), StatusCode> {
882    let mut activity = base_activity.clone();
883    apply_bot_defaults(&mut activity, conversation_id);
884    let stored = state
885        .conversations
886        .append(conversation_id, activity)
887        .await
888        .map_err(map_store_error)?;
889
890    if let Err(err) = state
891        .sessions
892        .update_watermark(conversation_id, Some((stored.watermark + 1).to_string()))
893        .await
894    {
895        warn!(
896            error = %err,
897            conversation = %conversation_id,
898            "failed to update watermark"
899        );
900    }
901
902    Ok(())
903}
904
905fn map_store_error(err: super::conversation::StoreError) -> StatusCode {
906    match err {
907        super::conversation::StoreError::AlreadyExists(_) => StatusCode::CONFLICT,
908        super::conversation::StoreError::NotFound(_) => StatusCode::NOT_FOUND,
909        super::conversation::StoreError::QuotaExceeded(_) => StatusCode::TOO_MANY_REQUESTS,
910        super::conversation::StoreError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
911    }
912}
913
914fn map_error(context: &str, err: anyhow::Error) -> StatusCode {
915    error!("{context}: {err}");
916    StatusCode::INTERNAL_SERVER_ERROR
917}
918
/// Token-bucket rate limiter keyed by client IP address.
///
/// Each IP owns a bucket holding at most `capacity` tokens, refilled
/// continuously at `refill_per_sec`. Each allowed request spends one token.
struct IpRateLimiter {
    /// Largest number of tokens a bucket can hold (the permitted burst).
    capacity: f64,
    /// Tokens credited to a bucket per second of elapsed time.
    refill_per_sec: f64,
    /// Buckets of recently seen IPs. A bucket that has refilled to `capacity`
    /// is indistinguishable from a missing one, because new buckets start
    /// full. Such buckets are evicted once the map reaches `PRUNE_THRESHOLD`
    /// entries.
    buckets: Mutex<HashMap<IpAddr, TokenBucket>>,
}

/// Token balance of one IP as of the instant `last`.
struct TokenBucket {
    tokens: f64,
    last: Instant,
}

impl IpRateLimiter {
    /// Map size at which full buckets are evicted before a new IP is admitted.
    /// Without eviction, every distinct client IP would be kept forever.
    const PRUNE_THRESHOLD: usize = 4_096;

    /// Creates a limiter that allows bursts of `capacity` requests per IP, with
    /// the whole budget refilling over `window`.
    ///
    /// A zero `window` refills the whole budget every second.
    fn new(capacity: usize, window: Duration) -> Self {
        let capacity = capacity as f64;
        let refill_per_sec = if window.is_zero() {
            capacity
        } else {
            capacity / window.as_secs_f64()
        };
        Self {
            capacity,
            refill_per_sec,
            buckets: Mutex::new(HashMap::new()),
        }
    }

    /// Returns how many tokens `bucket` would hold at `now`, capped at
    /// `capacity`.
    fn refilled_tokens(&self, bucket: &TokenBucket, now: Instant) -> f64 {
        let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
        (bucket.tokens + elapsed * self.refill_per_sec).min(self.capacity)
    }

    /// Records a request from `ip`.
    ///
    /// Returns `true` when the request is within budget, and `false` when it
    /// should be throttled.
    fn check(&self, ip: IpAddr) -> bool {
        // A poisoned lock only means another thread panicked while holding
        // it. The map holds plain numbers and instants, so reusing it is safe
        // and avoids panicking on every later request.
        let mut buckets = self
            .buckets
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Read the clock under the lock so successive callers observe
        // non-decreasing instants.
        let now = Instant::now();

        if buckets.len() >= Self::PRUNE_THRESHOLD && !buckets.contains_key(&ip) {
            // Dropping a full bucket cannot change any decision: an unknown IP
            // is given a fresh, full bucket anyway.
            buckets.retain(|_, bucket| self.refilled_tokens(bucket, now) < self.capacity);
        }

        let bucket = buckets.entry(ip).or_insert(TokenBucket {
            tokens: self.capacity,
            last: now,
        });
        if now > bucket.last {
            bucket.tokens = self.refilled_tokens(bucket, now);
            bucket.last = now;
        }
        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
965
// Handler-level tests for the standalone Direct Line server. Handlers are
// called directly (no HTTP listener) against an in-memory conversation store,
// an in-memory session store, and a bus that records every published event.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::platforms::webchat::{
        WebChatProvider,
        bus::{EventBus, Subject},
        config::Config,
        // NOTE(review): presumably imported so `sessions.get(...)` resolves as a
        // trait method in `admin_bot_activity_appends_and_streams` — confirm.
        session::WebchatSessionStore,
        types::{GreenticEvent, MessagePayload},
    };
    use axum::{
        Extension, Json,
        extract::{Path, Query},
        http::{self, HeaderMap, HeaderValue, StatusCode},
        response::IntoResponse,
    };
    use greentic_secrets::spec::{
        Scope, SecretUri, SecretsBackend, VersionedSecret, helpers::record_from_plain,
    };
    use serde_json::json;
    // This explicit import shadows the `std::sync::Mutex` brought in by
    // `use super::*`. `RecordingBus` below awaits this async lock.
    use tokio::sync::Mutex;

    /// Secrets backend stub that only serves the webchat JWT signing key.
    ///
    /// `get` and `exists` recognise the URI with category `webchat` and name
    /// `jwt_signing_key`. Every other lookup misses. The write, list, delete and
    /// versions operations panic via `unimplemented!`, so any call to them fails
    /// the test.
    #[derive(Clone)]
    struct TestSecretsBackend {
        // Plain-text key returned for the signing-key URI.
        secret: String,
    }

    impl TestSecretsBackend {
        fn new(secret: String) -> Self {
            Self { secret }
        }
    }

    impl SecretsBackend for TestSecretsBackend {
        fn put(
            &self,
            _record: greentic_secrets::spec::SecretRecord,
        ) -> greentic_secrets::spec::Result<greentic_secrets::spec::SecretVersion> {
            unimplemented!("test backend does not support writes")
        }

        // Returns the configured secret as version 1 (not deleted) for the
        // signing-key URI, and `Ok(None)` for anything else.
        fn get(
            &self,
            uri: &SecretUri,
            _version: Option<u64>,
        ) -> greentic_secrets::spec::Result<Option<VersionedSecret>> {
            if uri.category() == "webchat" && uri.name() == "jwt_signing_key" {
                let record = record_from_plain(self.secret.clone());
                Ok(Some(VersionedSecret {
                    version: 1,
                    deleted: false,
                    record: Some(record),
                }))
            } else {
                Ok(None)
            }
        }

        fn list(
            &self,
            _scope: &greentic_secrets::spec::Scope,
            _category_prefix: Option<&str>,
            _name_prefix: Option<&str>,
        ) -> greentic_secrets::spec::Result<Vec<greentic_secrets::spec::SecretListItem>> {
            unimplemented!("test backend does not support listing")
        }

        fn delete(
            &self,
            _uri: &SecretUri,
        ) -> greentic_secrets::spec::Result<greentic_secrets::spec::SecretVersion> {
            unimplemented!("test backend does not support delete")
        }

        fn versions(
            &self,
            _uri: &SecretUri,
        ) -> greentic_secrets::spec::Result<Vec<greentic_secrets::spec::SecretVersion>> {
            unimplemented!("test backend does not support versions")
        }

        // Must agree with `get`: only the signing-key URI exists.
        fn exists(&self, uri: &SecretUri) -> greentic_secrets::spec::Result<bool> {
            Ok(uri.category() == "webchat" && uri.name() == "jwt_signing_key")
        }
    }

    /// Builds a `WebChatProvider` for `base_url` backed by [`TestSecretsBackend`]
    /// with the fixed key `"test-signing-key"` and a signing scope built from
    /// `("global", "webchat", None)`.
    fn test_provider(base_url: &str) -> WebChatProvider {
        let backend = Arc::new(TestSecretsBackend::new("test-signing-key".to_string()));
        let scope = Scope::new("global", "webchat", None).expect("valid signing scope");
        WebChatProvider::new(Config::with_base_url(base_url), backend).with_signing_scope(scope)
    }

    /// A user `message` activity is delivered to conversation subscribers with
    /// its text intact. Exactly one `IncomingMessage` event with a `Text`
    /// payload is published to the bus.
    #[tokio::test]
    async fn user_activity_is_normalized_and_streamed() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(
                provider,
                conversations.clone(),
                sessions.clone(),
                bus.clone(),
            )
            .await
            .expect("state"),
        );

        let user_token = issue_token(&state, "user-1").await;
        let (conversation_id, conversation_token) = start_conversation(&state, &user_token).await;

        // Subscribe before posting. NOTE(review): presumably the subscription
        // only observes activities appended after it is created — confirm
        // against the store implementation.
        let mut subscriber = conversations
            .subscribe(&conversation_id)
            .await
            .expect("subscribe");

        post_user_activity(
            &state,
            &conversation_id,
            &conversation_token,
            json!({
                "type": "message",
                "text": "hello"
            }),
        )
        .await;

        let stored = subscriber.recv().await.expect("activity");
        assert_eq!(stored.activity.text.as_deref(), Some("hello"));

        let events = bus.take().await;
        assert_eq!(events.len(), 1);
        match &events[0] {
            GreenticEvent::IncomingMessage(msg) => match &msg.payload {
                MessagePayload::Text { text, .. } => assert_eq!(text, "hello"),
                other => panic!("unexpected payload: {other:?}"),
            },
        }
    }

    /// A user `invoke` activity is published to the bus as a single
    /// `IncomingMessage` with an `Event` payload named after the invoke `name`.
    #[tokio::test]
    async fn invoke_is_normalized_to_event() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(provider, conversations.clone(), sessions, bus.clone())
                .await
                .expect("state"),
        );
        let user_token = issue_token(&state, "user-2").await;
        let (conversation_id, conversation_token) = start_conversation(&state, &user_token).await;

        post_user_activity(
            &state,
            &conversation_id,
            &conversation_token,
            json!({
                "type": "invoke",
                "name": "adaptiveCard/action",
                "value": {"foo": "bar"}
            }),
        )
        .await;

        let events = bus.take().await;
        assert_eq!(events.len(), 1);
        match &events[0] {
            GreenticEvent::IncomingMessage(msg) => match &msg.payload {
                MessagePayload::Event { name, .. } => assert_eq!(name, "adaptiveCard/action"),
                other => panic!("expected event payload, got {other:?}"),
            },
        }
    }

    /// An activity posted through the admin endpoint for a known conversation
    /// is counted as posted (none skipped) and streamed to subscribers with a
    /// `bot` role. It also advances the session watermark to `"1"` and
    /// publishes nothing to the bus.
    #[tokio::test]
    async fn admin_bot_activity_appends_and_streams() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(
                provider,
                conversations.clone(),
                sessions.clone(),
                bus.clone(),
            )
            .await
            .expect("state"),
        );

        let user_token = issue_token(&state, "user-3").await;
        let (conversation_id, _) = start_conversation(&state, &user_token).await;
        let mut subscriber = conversations
            .subscribe(&conversation_id)
            .await
            .expect("subscribe");

        // env/tenant match the values used by `issue_token`, so the admin call
        // targets the same tenant as the conversation.
        let Json(response) = admin_post_activity_handler(
            Extension(Arc::clone(&state)),
            Path(AdminPath {
                env: "dev".to_string(),
                tenant: "acme".to_string(),
            }),
            Json(AdminPostActivityRequest {
                team: None,
                conversation_id: Some(conversation_id.clone()),
                activity: json!({
                    "type": "message",
                    "text": "bot says hi"
                }),
            }),
        )
        .await
        .expect("admin post");

        assert_eq!(response.posted, 1);
        assert_eq!(response.skipped, 0);

        let stored = subscriber.recv().await.expect("activity");
        assert_eq!(stored.activity.text.as_deref(), Some("bot says hi"));
        // The stored activity's `from.role` must be "bot" (case-insensitive);
        // a missing `from` or role fails the assertion.
        assert!(
            stored
                .activity
                .from
                .as_ref()
                .and_then(|from| from.role.as_deref())
                .map(|role| role.eq_ignore_ascii_case("bot"))
                .unwrap_or(false)
        );

        let session = sessions
            .get(&conversation_id)
            .await
            .expect("session fetch")
            .expect("session exists");
        assert_eq!(session.watermark.as_deref(), Some("1"));

        // Bot-originated activities must not be echoed onto the bus.
        assert!(bus.take().await.is_empty());
    }

    /// Calls `generate_token_handler` for tenant `dev`/`acme` (no team) as
    /// `user_id` and returns the minted token.
    ///
    /// Passes `RemoteIp(None)`, i.e. no client address. NOTE(review): whether
    /// that bypasses per-IP rate limiting depends on the handler — confirm.
    async fn issue_token(state: &Arc<StandaloneState>, user_id: &str) -> String {
        let query = TenantQuery {
            env: "dev".to_string(),
            tenant: "acme".to_string(),
            team: None,
        };
        let body = GenerateTokenRequest {
            user: Some(UserDescriptor {
                id: Some(user_id.to_string()),
            }),
            trusted_origins: None,
        };
        let Json(response) = generate_token_handler(
            Extension(Arc::clone(state)),
            RemoteIp(None),
            Query(query),
            Json(body),
        )
        .await
        .expect("generate token");
        response.token
    }

    /// Starts a conversation authorised by `token` (sent as a Bearer header).
    /// Returns `(conversation_id, conversation_token)`.
    async fn start_conversation(state: &Arc<StandaloneState>, token: &str) -> (String, String) {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );
        let Json(response) = start_conversation_handler(Extension(Arc::clone(state)), headers)
            .await
            .expect("start conversation");
        (response.conversation_id, response.token)
    }

    /// Posts `body` as a JSON activity to `conversation_id` using `token` as the
    /// Bearer credential, and asserts the handler responds `201 Created`.
    async fn post_user_activity(
        state: &Arc<StandaloneState>,
        conversation_id: &str,
        token: &str,
        body: serde_json::Value,
    ) {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );
        headers.insert(
            http::header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
        let response = post_activity_handler(
            Extension(Arc::clone(state)),
            Path(ConversationPath {
                id: conversation_id.to_string(),
            }),
            headers,
            Json(body),
        )
        .await
        .expect("post activity");
        let response = response.into_response();
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    /// Event bus double that records every published event, ignoring the subject.
    #[derive(Default)]
    struct RecordingBus {
        events: Mutex<Vec<GreenticEvent>>,
    }

    impl RecordingBus {
        /// Drains and returns all events recorded so far, leaving the buffer empty.
        async fn take(&self) -> Vec<GreenticEvent> {
            let mut guard = self.events.lock().await;
            std::mem::take(&mut *guard)
        }
    }

    #[async_trait::async_trait]
    impl EventBus for RecordingBus {
        // Stores a clone of each event; never fails.
        async fn publish(&self, _subject: &Subject, event: &GreenticEvent) -> anyhow::Result<()> {
            self.events.lock().await.push(event.clone());
            Ok(())
        }
    }
}