Skip to main content

gsm_core/platforms/webchat/
standalone.rs

//! Standalone Direct Line server implementation.
//!
//! This module exposes a Direct Line–compatible surface area that aligns with
//! Microsoft Bot Framework expectations while keeping all state within the
//! Greentic stack. Tokens are minted locally, and conversations are persisted in
//! and streamed from a pluggable conversation store (in-memory by default; see
//! [`StandaloneState::new`]).
7
8use std::{
9    collections::HashMap,
10    net::{IpAddr, SocketAddr},
11    sync::{Arc, Mutex},
12    time::{Duration, Instant},
13};
14
15use anyhow::Context as _;
16use axum::{
17    Json, Router,
18    extract::{
19        ConnectInfo, Extension, FromRequestParts, Path, Query, WebSocketUpgrade,
20        ws::{Message, WebSocket},
21    },
22    http::{HeaderMap, StatusCode, request::Parts},
23    response::IntoResponse,
24    routing::{get, post},
25};
26use serde::{Deserialize, Serialize};
27use tracing::{debug, error, warn};
28use uuid::Uuid;
29
30use super::{
31    WebChatProvider,
32    auth::{self as jwt, Claims, TenantClaims},
33    bus::{NoopBus, SharedBus},
34    conversation::{
35        Activity, ChannelAccount, ConversationAccount, SharedConversationStore, StoredActivity,
36        memory_store,
37    },
38    ingress,
39    session::{MemorySessionStore, SharedSessionStore, WebchatSession},
40};
41use greentic_types::{EnvId, TeamId, TenantCtx, TenantId};
42
/// Lifetime of every minted token, in seconds (30 minutes).
const TOKEN_TTL_SECONDS: u64 = 30 * 60;
/// Token-generation requests a single client IP may burst.
const RATE_LIMIT_CAPACITY: usize = 5;
/// Period over which a drained rate-limit bucket fully refills.
const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(60);
46
47struct RemoteIp(Option<IpAddr>);
48
49impl<S> FromRequestParts<S> for RemoteIp
50where
51    S: Send + Sync,
52{
53    type Rejection = std::convert::Infallible;
54
55    fn from_request_parts(
56        parts: &mut Parts,
57        _state: &S,
58    ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
59        let ip = parts
60            .extensions
61            .get::<ConnectInfo<SocketAddr>>()
62            .map(|ConnectInfo(addr)| addr.ip());
63        std::future::ready(Ok(RemoteIp(ip)))
64    }
65}
66
67/// Shared state for the standalone Direct Line server.
68#[derive(Clone)]
69pub struct StandaloneState {
70    pub provider: WebChatProvider,
71    pub conversations: SharedConversationStore,
72    pub sessions: SharedSessionStore,
73    pub bus: SharedBus,
74    rate_limiter: Arc<IpRateLimiter>,
75}
76
77impl StandaloneState {
78    /// Builds a new state wired to the provided conversation store.
79    pub async fn with_store(
80        provider: WebChatProvider,
81        conversations: SharedConversationStore,
82        sessions: SharedSessionStore,
83        bus: SharedBus,
84    ) -> anyhow::Result<Self> {
85        let signing_keys = provider
86            .signing_keys()
87            .await
88            .context("failed to resolve Direct Line signing key")?;
89        if let Err(err) = jwt::install_keys(signing_keys) {
90            debug!("jwt keys already installed: {err}");
91        }
92        Ok(Self {
93            provider,
94            conversations,
95            sessions,
96            bus,
97            rate_limiter: Arc::new(IpRateLimiter::new(RATE_LIMIT_CAPACITY, RATE_LIMIT_WINDOW)),
98        })
99    }
100
101    /// Builds a new state with the default in-memory conversation store.
102    pub async fn new(provider: WebChatProvider) -> anyhow::Result<Self> {
103        Self::with_store(
104            provider,
105            memory_store(),
106            Arc::new(MemorySessionStore::default()),
107            Arc::new(NoopBus),
108        )
109        .await
110    }
111}
112
113fn router_blueprint() -> Router {
114    Router::new()
115        .route(
116            "/v3/directline/tokens/generate",
117            post(generate_token_handler),
118        )
119        .route(
120            "/v3/directline/conversations",
121            post(start_conversation_handler),
122        )
123        .route(
124            "/v3/directline/conversations/{id}/activities",
125            get(list_activities_handler).post(post_activity_handler),
126        )
127        .route(
128            "/v3/directline/conversations/{id}/stream",
129            get(conversation_stream_handler),
130        )
131        .route(
132            "/webchat/admin/{env}/{tenant}/post-activity",
133            post(admin_post_activity_handler),
134        )
135}
136
137/// Builds the Axum router that serves `/v3/directline` endpoints.
138pub fn router(state: Arc<StandaloneState>) -> Router {
139    router_blueprint().layer(axum::Extension(state))
140}
141
142#[allow(dead_code)]
143const fn assert_state_bounds<T: Send + Sync + Clone>() {}
144const _: () = {
145    assert_state_bounds::<StandaloneState>();
146};
147
148#[doc(hidden)]
149#[derive(Debug, Deserialize)]
150pub struct TenantQuery {
151    env: String,
152    tenant: String,
153    #[serde(default)]
154    team: Option<String>,
155}
156
157#[doc(hidden)]
158#[derive(Debug, Deserialize)]
159#[serde(rename_all = "camelCase")]
160pub struct GenerateTokenRequest {
161    #[serde(default)]
162    user: Option<UserDescriptor>,
163    #[allow(dead_code)]
164    #[serde(default)]
165    trusted_origins: Option<Vec<String>>,
166}
167
168#[doc(hidden)]
169#[derive(Debug, Deserialize)]
170pub struct UserDescriptor {
171    #[serde(default)]
172    id: Option<String>,
173}
174
175#[derive(Debug, Serialize, Deserialize)]
176struct TokenResponse {
177    token: String,
178    expires_in: u64,
179}
180
181async fn generate_token_handler(
182    Extension(state): Extension<Arc<StandaloneState>>,
183    remote: RemoteIp,
184    Query(query): Query<TenantQuery>,
185    Json(body): Json<GenerateTokenRequest>,
186) -> Result<Json<TokenResponse>, StatusCode> {
187    let ip = remote.0.unwrap_or_else(|| IpAddr::from([127, 0, 0, 1]));
188    if !state.rate_limiter.check(ip) {
189        return Err(StatusCode::TOO_MANY_REQUESTS);
190    }
191
192    let tenant_ctx = tenant_ctx_from_query(&query)?;
193    let subject = body
194        .user
195        .and_then(|user| user.id)
196        .filter(|id| !id.trim().is_empty())
197        .unwrap_or_else(|| format!("user:{}", Uuid::new_v4()));
198
199    let claims = Claims::new(
200        subject,
201        tenant_claims_from_ctx(&tenant_ctx),
202        jwt::ttl(TOKEN_TTL_SECONDS),
203    );
204    let token = jwt::sign(&claims).map_err(|err| map_error("sign token", err))?;
205    Ok(Json(TokenResponse {
206        token,
207        expires_in: TOKEN_TTL_SECONDS,
208    }))
209}
210
211#[doc(hidden)]
212#[derive(Debug, Deserialize)]
213pub struct ConversationPath {
214    id: String,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218#[serde(rename_all = "camelCase")]
219struct ConversationResponse {
220    conversation_id: String,
221    token: String,
222    #[serde(skip_serializing_if = "Option::is_none")]
223    stream_url: Option<String>,
224}
225
226async fn start_conversation_handler(
227    Extension(state): Extension<Arc<StandaloneState>>,
228    headers: HeaderMap,
229) -> Result<Json<ConversationResponse>, StatusCode> {
230    let token = extract_bearer(&headers)?;
231    let claims = jwt::verify(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
232    if claims.conv.is_some() {
233        return Err(StatusCode::BAD_REQUEST);
234    }
235    let tenant_ctx = tenant_ctx_from_claims(&claims)?;
236    let conversation_id = Uuid::new_v4().to_string();
237    state
238        .conversations
239        .create(&conversation_id, tenant_ctx.clone())
240        .await
241        .map_err(map_store_error)?;
242
243    let conversation_token = Claims::new(
244        claims.sub.clone(),
245        claims.ctx.clone(),
246        jwt::ttl(TOKEN_TTL_SECONDS),
247    )
248    .with_conversation(conversation_id.clone());
249    let token = jwt::sign(&conversation_token).map_err(|err| map_error("sign token", err))?;
250
251    if let Err(err) = state
252        .sessions
253        .upsert(WebchatSession::new(
254            conversation_id.clone(),
255            tenant_ctx.clone(),
256            token.clone(),
257        ))
258        .await
259    {
260        error!(error = %err, "failed to persist webchat session");
261        return Err(StatusCode::INTERNAL_SERVER_ERROR);
262    }
263
264    let stream_url = stream_url_from_headers(&state, &headers, &conversation_id, &token);
265    Ok(Json(ConversationResponse {
266        conversation_id,
267        token,
268        stream_url,
269    }))
270}
271
272#[doc(hidden)]
273#[derive(Debug, Deserialize)]
274pub struct ActivitiesQuery {
275    #[serde(default)]
276    watermark: Option<String>,
277}
278
279#[derive(Debug, Serialize)]
280struct ActivitiesResponse {
281    activities: Vec<Activity>,
282    watermark: String,
283}
284
285async fn list_activities_handler(
286    Extension(state): Extension<Arc<StandaloneState>>,
287    Path(path): Path<ConversationPath>,
288    headers: HeaderMap,
289    Query(query): Query<ActivitiesQuery>,
290) -> Result<Json<ActivitiesResponse>, StatusCode> {
291    let _claims = validate_conversation_token(state.as_ref(), &path.id, &headers).await?;
292    let watermark = query
293        .watermark
294        .as_deref()
295        .and_then(|value| {
296            let trimmed = value.trim();
297            if trimmed.is_empty() {
298                None
299            } else {
300                Some(trimmed)
301            }
302        })
303        .map(parse_watermark)
304        .transpose()?;
305    let page = state
306        .conversations
307        .activities(&path.id, watermark)
308        .await
309        .map_err(map_store_error)?;
310    let activities: Vec<Activity> = page
311        .activities
312        .into_iter()
313        .map(|entry| entry.activity)
314        .collect();
315    Ok(Json(ActivitiesResponse {
316        activities,
317        watermark: page.watermark.to_string(),
318    }))
319}
320
321#[derive(Debug, Serialize)]
322struct ActivityAck {
323    id: String,
324}
325
326async fn post_activity_handler(
327    Extension(state): Extension<Arc<StandaloneState>>,
328    Path(path): Path<ConversationPath>,
329    headers: HeaderMap,
330    Json(payload): Json<serde_json::Value>,
331) -> Result<impl IntoResponse, StatusCode> {
332    let claims = validate_conversation_token(state.as_ref(), &path.id, &headers).await?;
333    let tenant_ctx = state
334        .conversations
335        .tenant_ctx(&path.id)
336        .await
337        .map_err(map_store_error)?;
338
339    let mut bus_activity = payload.clone();
340    apply_user_json_defaults(&mut bus_activity, &path.id, &claims.sub);
341
342    let mut activity: Activity =
343        serde_json::from_value(payload).map_err(|_| StatusCode::BAD_REQUEST)?;
344    apply_user_defaults(&mut activity, &path.id, &claims.sub);
345
346    let ingress = ingress::Ingress::new(state.bus.clone(), state.sessions.clone());
347    if let Err(err) = ingress
348        .publish_incoming(&bus_activity, &tenant_ctx, &path.id)
349        .await
350    {
351        warn!(error = %err, "failed to publish incoming activity");
352    }
353
354    let stored = state
355        .conversations
356        .append(&path.id, activity)
357        .await
358        .map_err(map_store_error)?;
359
360    if let Err(err) = state
361        .sessions
362        .update_watermark(&path.id, Some((stored.watermark + 1).to_string()))
363        .await
364    {
365        warn!(error = %err, "failed to update watermark");
366    }
367    Ok((
368        StatusCode::CREATED,
369        Json(ActivityAck {
370            id: stored.activity.id,
371        }),
372    ))
373}
374
375#[derive(Debug, Deserialize)]
376pub struct AdminPath {
377    pub env: String,
378    pub tenant: String,
379}
380
381#[derive(Debug, Deserialize)]
382pub struct AdminPostActivityRequest {
383    #[serde(default)]
384    pub team: Option<String>,
385    #[serde(rename = "conversation_id", default)]
386    pub conversation_id: Option<String>,
387    pub activity: serde_json::Value,
388}
389
390#[derive(Debug, Serialize)]
391pub struct AdminPostActivityResponse {
392    pub posted: usize,
393    pub skipped: usize,
394}
395
396async fn admin_post_activity_handler(
397    Extension(state): Extension<Arc<StandaloneState>>,
398    Path(path): Path<AdminPath>,
399    Json(body): Json<AdminPostActivityRequest>,
400) -> Result<Json<AdminPostActivityResponse>, StatusCode> {
401    let AdminPostActivityRequest {
402        team,
403        conversation_id,
404        activity,
405    } = body;
406
407    if !activity.is_object() {
408        return Err(StatusCode::BAD_REQUEST);
409    }
410
411    let base_activity: Activity =
412        serde_json::from_value(activity).map_err(|_| StatusCode::BAD_REQUEST)?;
413    let team_filter = team.as_deref();
414
415    if let Some(conversation_id) = conversation_id {
416        let session = state
417            .sessions
418            .get(&conversation_id)
419            .await
420            .map_err(|err| map_error("load session", err))?
421            .ok_or(StatusCode::NOT_FOUND)?;
422
423        if !session
424            .tenant_ctx
425            .env
426            .as_ref()
427            .eq_ignore_ascii_case(&path.env)
428            || !session
429                .tenant_ctx
430                .tenant
431                .as_ref()
432                .eq_ignore_ascii_case(&path.tenant)
433        {
434            return Err(StatusCode::NOT_FOUND);
435        }
436
437        if let Some(team) = team_filter
438            && !session
439                .tenant_ctx
440                .team
441                .as_ref()
442                .map(|value| value.as_ref().eq_ignore_ascii_case(team))
443                .unwrap_or(false)
444        {
445            return Err(StatusCode::NOT_FOUND);
446        }
447
448        if !session.proactive_ok {
449            return Err(StatusCode::BAD_REQUEST);
450        }
451
452        append_bot_activity(state.as_ref(), &conversation_id, &base_activity).await?;
453
454        return Ok(Json(AdminPostActivityResponse {
455            posted: 1,
456            skipped: 0,
457        }));
458    }
459
460    let sessions = state
461        .sessions
462        .list_by_tenant(&path.env, &path.tenant, team_filter)
463        .await
464        .map_err(|err| map_error("list sessions", err))?;
465
466    if sessions.is_empty() {
467        return Err(StatusCode::NOT_FOUND);
468    }
469
470    let mut posted = 0usize;
471    let mut skipped = 0usize;
472    for session in sessions {
473        if !session.proactive_ok {
474            skipped += 1;
475            continue;
476        }
477
478        match append_bot_activity(state.as_ref(), &session.conversation_id, &base_activity).await {
479            Ok(()) => posted += 1,
480            Err(StatusCode::NOT_FOUND) => {
481                skipped += 1;
482                warn!(
483                    conversation = %session.conversation_id,
484                    "conversation not found while appending activity"
485                );
486            }
487            Err(code) => return Err(code),
488        }
489    }
490
491    if posted == 0 {
492        return Err(StatusCode::NOT_FOUND);
493    }
494
495    Ok(Json(AdminPostActivityResponse { posted, skipped }))
496}
497
498#[derive(Debug, Deserialize)]
499struct StreamQuery {
500    t: String,
501    #[serde(default)]
502    watermark: Option<String>,
503}
504
505async fn conversation_stream_handler(
506    Extension(state): Extension<Arc<StandaloneState>>,
507    Path(path): Path<ConversationPath>,
508    Query(query): Query<StreamQuery>,
509    ws: WebSocketUpgrade,
510) -> Result<impl IntoResponse, StatusCode> {
511    let claims = jwt::verify(&query.t).map_err(|_| StatusCode::UNAUTHORIZED)?;
512    ensure_conversation_access(state.as_ref(), &path.id, &claims).await?;
513    let watermark = query
514        .watermark
515        .as_deref()
516        .map(parse_watermark)
517        .transpose()?;
518    Ok(ws.on_upgrade(move |socket| async move {
519        if let Err(err) = run_websocket(socket, state, path.id, watermark).await {
520            warn!("websocket closed with error: {err:?}");
521        }
522    }))
523}
524
525async fn run_websocket(
526    mut socket: WebSocket,
527    state: Arc<StandaloneState>,
528    conversation_id: String,
529    watermark: Option<u64>,
530) -> anyhow::Result<()> {
531    let page = state
532        .conversations
533        .activities(&conversation_id, watermark)
534        .await
535        .map_err(|err| anyhow::anyhow!(err.to_string()))?;
536    let initial: Vec<StoredActivity> = page.activities.iter().cloned().collect();
537    if !initial.is_empty() {
538        send_envelope(&mut socket, &initial, page.watermark).await?;
539    }
540    let mut consecutive_failures: u32 = 0;
541    const SEND_FAILURE_THRESHOLD: u32 = 5;
542    let mut current_watermark = page.watermark;
543    let mut subscriber = state
544        .conversations
545        .subscribe(&conversation_id)
546        .await
547        .map_err(|err| anyhow::anyhow!(err.to_string()))?;
548
549    loop {
550        tokio::select! {
551            message = socket.recv() => {
552                match message {
553                    Some(Ok(Message::Close(_))) | None => break,
554                    Some(Ok(_)) => continue,
555                    Some(Err(err)) => {
556                        warn!("websocket recv error: {err}");
557                        break;
558                    }
559                }
560            }
561            received = subscriber.recv() => {
562                match received {
563                    Ok(activity) => {
564                        if let Err(err) = send_envelope(&mut socket, std::slice::from_ref(&activity), activity.watermark + 1).await {
565                            consecutive_failures = consecutive_failures.saturating_add(1);
566                            warn!(error = ?err, consecutive_failures, "websocket send error");
567                            if consecutive_failures >= SEND_FAILURE_THRESHOLD {
568                                warn!("terminating websocket due to repeated send failures");
569                                break;
570                            }
571                            continue;
572                        } else {
573                            consecutive_failures = 0;
574                        }
575                        current_watermark = activity.watermark + 1;
576                    }
577                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
578                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
579                        let page = state.conversations.activities(&conversation_id, Some(current_watermark))
580                            .await
581                            .map_err(|err| anyhow::anyhow!(err.to_string()))?;
582                        let resend: Vec<StoredActivity> = page.activities.iter().cloned().collect();
583                        if !resend.is_empty() {
584                            if let Err(err) = send_envelope(&mut socket, &resend, page.watermark).await {
585                                consecutive_failures = consecutive_failures.saturating_add(1);
586                                warn!(error = ?err, consecutive_failures, "websocket resend error");
587                                if consecutive_failures >= SEND_FAILURE_THRESHOLD {
588                                    warn!("terminating websocket due to repeated resend failures");
589                                    break;
590                                }
591                                continue;
592                            } else {
593                                consecutive_failures = 0;
594                            }
595                        }
596                        current_watermark = page.watermark;
597                    }
598                }
599            }
600        }
601    }
602    Ok(())
603}
604
605#[doc(hidden)]
606pub async fn test_run_websocket_public(
607    socket: WebSocket,
608    state: Arc<StandaloneState>,
609    conversation_id: String,
610    watermark: Option<u64>,
611) -> anyhow::Result<()> {
612    run_websocket(socket, state, conversation_id, watermark).await
613}
614
615async fn send_envelope(
616    socket: &mut WebSocket,
617    activities: &[StoredActivity],
618    watermark: u64,
619) -> anyhow::Result<()> {
620    if activities.is_empty() {
621        return Ok(());
622    }
623    let payload = envelope_payload(activities, watermark)?;
624    socket.send(Message::Text(payload.into())).await?;
625    Ok(())
626}
627
628fn envelope_payload(activities: &[StoredActivity], watermark: u64) -> anyhow::Result<String> {
629    let acts: Vec<Activity> = activities
630        .iter()
631        .map(|entry| entry.activity.clone())
632        .collect();
633    Ok(serde_json::to_string(&serde_json::json!({
634        "activities": acts,
635        "watermark": watermark.to_string(),
636    }))?)
637}
638
639fn tenant_ctx_from_query(query: &TenantQuery) -> Result<TenantCtx, StatusCode> {
640    let env = query.env.trim();
641    let tenant = query.tenant.trim();
642    if env.is_empty() || tenant.is_empty() {
643        return Err(StatusCode::BAD_REQUEST);
644    }
645    let mut ctx = TenantCtx::new(EnvId(env.to_string()), TenantId(tenant.to_string()));
646    if let Some(team) = &query.team {
647        let team = team.trim();
648        if !team.is_empty() {
649            ctx = ctx.with_team(Some(TeamId(team.to_string())));
650        }
651    }
652    Ok(ctx)
653}
654
655fn tenant_ctx_from_claims(claims: &Claims) -> Result<TenantCtx, StatusCode> {
656    let mut ctx = TenantCtx::new(
657        EnvId(claims.ctx.env.clone()),
658        TenantId(claims.ctx.tenant.clone()),
659    );
660    if let Some(team) = &claims.ctx.team {
661        ctx = ctx.with_team(Some(TeamId(team.clone())));
662    }
663    Ok(ctx)
664}
665
666fn tenant_claims_from_ctx(ctx: &TenantCtx) -> TenantClaims {
667    TenantClaims {
668        env: ctx.env.as_ref().to_string(),
669        tenant: ctx.tenant.as_ref().to_string(),
670        team: ctx.team.as_ref().map(|team| team.as_ref().to_string()),
671    }
672}
673
674fn extract_bearer(headers: &HeaderMap) -> Result<&str, StatusCode> {
675    let value = headers
676        .get(axum::http::header::AUTHORIZATION)
677        .and_then(|value| value.to_str().ok())
678        .ok_or(StatusCode::UNAUTHORIZED)?;
679    if let Some(rest) = value.strip_prefix("Bearer ") {
680        Ok(rest.trim())
681    } else if let Some(rest) = value.strip_prefix("bearer ") {
682        Ok(rest.trim())
683    } else {
684        Err(StatusCode::UNAUTHORIZED)
685    }
686}
687
688fn parse_watermark(value: &str) -> Result<u64, StatusCode> {
689    value.parse::<u64>().map_err(|_| StatusCode::BAD_REQUEST)
690}
691
692fn stream_url_from_headers(
693    state: &StandaloneState,
694    headers: &HeaderMap,
695    conversation_id: &str,
696    token: &str,
697) -> Option<String> {
698    if let Some(host) = headers
699        .get(axum::http::header::HOST)
700        .and_then(|value| value.to_str().ok())
701    {
702        let scheme = if state
703            .provider
704            .config()
705            .direct_line_base()
706            .starts_with("https://")
707        {
708            "wss"
709        } else {
710            "ws"
711        };
712        Some(format!(
713            "{scheme}://{host}/v3/directline/conversations/{conversation_id}/stream?t={token}"
714        ))
715    } else {
716        None
717    }
718}
719
720async fn validate_conversation_token(
721    state: &StandaloneState,
722    conversation_id: &str,
723    headers: &HeaderMap,
724) -> Result<Claims, StatusCode> {
725    let token = extract_bearer(headers)?;
726    let claims = jwt::verify(token).map_err(|_| StatusCode::UNAUTHORIZED)?;
727    ensure_conversation_access(state, conversation_id, &claims).await?;
728    Ok(claims)
729}
730
731async fn ensure_conversation_access(
732    state: &StandaloneState,
733    conversation_id: &str,
734    claims: &Claims,
735) -> Result<(), StatusCode> {
736    if !claims.has_conversation(conversation_id) {
737        return Err(StatusCode::FORBIDDEN);
738    }
739    let claimed_ctx = tenant_ctx_from_claims(claims)?;
740    let stored_ctx = state
741        .conversations
742        .tenant_ctx(conversation_id)
743        .await
744        .map_err(map_store_error)?;
745    if stored_ctx != claimed_ctx {
746        return Err(StatusCode::FORBIDDEN);
747    }
748    Ok(())
749}
750
751fn normalise_activity(activity: &mut Activity, conversation_id: &str, subject: &str) {
752    if activity
753        .from
754        .as_ref()
755        .map(|from| from.id.trim().is_empty())
756        .unwrap_or(true)
757    {
758        activity.from = Some(ChannelAccount {
759            id: subject.to_string(),
760            name: None,
761            role: Some("user".into()),
762        });
763    }
764    if activity.conversation.is_none() {
765        activity.conversation = Some(ConversationAccount {
766            id: conversation_id.to_string(),
767        });
768    }
769}
770
771fn apply_bot_defaults(activity: &mut Activity, conversation_id: &str) {
772    normalise_activity(activity, conversation_id, "bot");
773    if let Some(from) = activity.from.as_mut() {
774        if from.id.trim().is_empty() {
775            from.id = "bot".into();
776        }
777        from.role = Some("bot".into());
778    } else {
779        activity.from = Some(ChannelAccount {
780            id: "bot".into(),
781            name: None,
782            role: Some("bot".into()),
783        });
784    }
785}
786
787fn apply_user_json_defaults(
788    activity: &mut serde_json::Value,
789    conversation_id: &str,
790    subject: &str,
791) {
792    let Some(obj) = activity.as_object_mut() else {
793        return;
794    };
795
796    let from_entry = obj
797        .entry("from".to_string())
798        .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
799    if !from_entry.is_object() {
800        *from_entry = serde_json::Value::Object(serde_json::Map::new());
801    }
802    if let Some(from_obj) = from_entry.as_object_mut() {
803        let id_entry = from_obj
804            .entry("id".to_string())
805            .or_insert_with(|| serde_json::Value::String(subject.to_string()));
806        if id_entry
807            .as_str()
808            .map(|value| value.trim().is_empty())
809            .unwrap_or(true)
810        {
811            *id_entry = serde_json::Value::String(subject.to_string());
812        }
813        let role_entry = from_obj
814            .entry("role".to_string())
815            .or_insert_with(|| serde_json::Value::String("user".into()));
816        if !role_entry
817            .as_str()
818            .map(|role| role.eq_ignore_ascii_case("user"))
819            .unwrap_or(false)
820        {
821            *role_entry = serde_json::Value::String("user".into());
822        }
823    }
824
825    let conversation_entry = obj.entry("conversation".to_string()).or_insert_with(|| {
826        let mut conv = serde_json::Map::new();
827        conv.insert(
828            "id".to_string(),
829            serde_json::Value::String(conversation_id.to_string()),
830        );
831        serde_json::Value::Object(conv)
832    });
833    if let Some(conv_obj) = conversation_entry.as_object_mut() {
834        let id_entry = conv_obj
835            .entry("id".to_string())
836            .or_insert_with(|| serde_json::Value::String(conversation_id.to_string()));
837        if id_entry
838            .as_str()
839            .map(|value| value.trim().is_empty())
840            .unwrap_or(true)
841        {
842            *id_entry = serde_json::Value::String(conversation_id.to_string());
843        }
844    } else {
845        let mut conv = serde_json::Map::new();
846        conv.insert(
847            "id".to_string(),
848            serde_json::Value::String(conversation_id.to_string()),
849        );
850        *conversation_entry = serde_json::Value::Object(conv);
851    }
852}
853
854fn apply_user_defaults(activity: &mut Activity, conversation_id: &str, subject: &str) {
855    normalise_activity(activity, conversation_id, subject);
856    if let Some(from) = activity.from.as_mut() {
857        if from
858            .role
859            .as_deref()
860            .map(|role| role.eq_ignore_ascii_case("user"))
861            .unwrap_or(true)
862        {
863            from.role = Some("user".into());
864        }
865        if from.id.trim().is_empty() {
866            from.id = subject.to_string();
867        }
868    } else {
869        activity.from = Some(ChannelAccount {
870            id: subject.to_string(),
871            name: None,
872            role: Some("user".into()),
873        });
874    }
875}
876
877async fn append_bot_activity(
878    state: &StandaloneState,
879    conversation_id: &str,
880    base_activity: &Activity,
881) -> Result<(), StatusCode> {
882    let mut activity = base_activity.clone();
883    apply_bot_defaults(&mut activity, conversation_id);
884    let stored = state
885        .conversations
886        .append(conversation_id, activity)
887        .await
888        .map_err(map_store_error)?;
889
890    if let Err(err) = state
891        .sessions
892        .update_watermark(conversation_id, Some((stored.watermark + 1).to_string()))
893        .await
894    {
895        warn!(
896            error = %err,
897            conversation = %conversation_id,
898            "failed to update watermark"
899        );
900    }
901
902    Ok(())
903}
904
905fn map_store_error(err: super::conversation::StoreError) -> StatusCode {
906    match err {
907        super::conversation::StoreError::AlreadyExists(_) => StatusCode::CONFLICT,
908        super::conversation::StoreError::NotFound(_) => StatusCode::NOT_FOUND,
909        super::conversation::StoreError::QuotaExceeded(_) => StatusCode::TOO_MANY_REQUESTS,
910        super::conversation::StoreError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
911    }
912}
913
914fn map_error(context: &str, err: anyhow::Error) -> StatusCode {
915    error!("{context}: {err}");
916    StatusCode::INTERNAL_SERVER_ERROR
917}
918
/// Per-IP token-bucket rate limiter.
///
/// Each client IP owns a bucket holding at most `capacity` tokens. It refills
/// continuously at `capacity / window` tokens per second, and a request is
/// admitted when at least one whole token is available.
///
/// The bucket map is bounded: once it holds `max_tracked` entries, inserting a
/// new IP first evicts buckets that have refilled to capacity. Without this,
/// every distinct client address ever seen stayed in memory forever.
struct IpRateLimiter {
    capacity: f64,
    refill_per_sec: f64,
    /// Map size at which fully-refilled buckets are pruned before inserting.
    max_tracked: usize,
    buckets: Mutex<HashMap<IpAddr, TokenBucket>>,
}

/// Token balance of one IP, as of `last`.
struct TokenBucket {
    tokens: f64,
    last: Instant,
}

impl IpRateLimiter {
    /// Default number of tracked IPs before idle buckets are pruned.
    const DEFAULT_MAX_TRACKED: usize = 10_000;

    /// Creates a limiter allowing `capacity` requests per `window` per IP.
    fn new(capacity: usize, window: Duration) -> Self {
        Self::with_max_tracked(capacity, window, Self::DEFAULT_MAX_TRACKED)
    }

    /// Like [`IpRateLimiter::new`] with an explicit bucket-map size bound.
    fn with_max_tracked(capacity: usize, window: Duration, max_tracked: usize) -> Self {
        let capacity = capacity as f64;
        // A zero window would divide by zero; refill the full capacity each second.
        let refill_per_sec = if window.is_zero() {
            capacity
        } else {
            capacity / window.as_secs_f64()
        };
        Self {
            capacity,
            refill_per_sec,
            max_tracked,
            buckets: Mutex::new(HashMap::new()),
        }
    }

    /// Consumes one token for `ip`, returning `false` when its bucket is empty.
    fn check(&self, ip: IpAddr) -> bool {
        // Bucket updates are plain arithmetic that cannot be left half-applied,
        // so a poisoned lock is recovered rather than panicking on every request.
        let mut buckets = self
            .buckets
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();

        if buckets.len() >= self.max_tracked && !buckets.contains_key(&ip) {
            self.evict_refilled(&mut buckets, now);
        }

        let bucket = buckets.entry(ip).or_insert(TokenBucket {
            tokens: self.capacity,
            last: now,
        });
        let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
        if elapsed > 0.0 {
            bucket.tokens = (bucket.tokens + elapsed * self.refill_per_sec).min(self.capacity);
            bucket.last = now;
        }
        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Drops buckets that would be full by `now`. Such a bucket behaves exactly
    /// like a freshly created one, so forgetting it changes no admission decision.
    fn evict_refilled(&self, buckets: &mut HashMap<IpAddr, TokenBucket>, now: Instant) {
        buckets.retain(|_, bucket| {
            let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
            bucket.tokens + elapsed * self.refill_per_sec < self.capacity
        });
    }
}
965
#[cfg(test)]
mod tests {
    //! Handler-level tests for the standalone Direct Line server.
    //!
    //! The tests call the axum handlers directly (no HTTP listener). They wire
    //! them to an in-memory conversation store, an in-memory session store, and
    //! a bus that records every published event, so both the streamed
    //! activities and the bus traffic can be asserted.
    use super::*;

    use crate::platforms::webchat::{
        WebChatProvider,
        bus::{EventBus, Subject},
        config::Config,
        session::WebchatSessionStore,
        types::{GreenticEvent, MessagePayload},
    };
    use axum::{
        Extension, Json,
        extract::{Path, Query},
        http::{self, HeaderMap, HeaderValue, StatusCode},
        response::IntoResponse,
    };
    use greentic_secrets_spec::record_from_plain;
    use secrets_core::{Scope, SecretUri, SecretsBackend, VersionedSecret};
    use serde_json::json;
    // `RecordingBus` needs an async mutex. This explicit import takes
    // precedence over any `Mutex` brought in by the glob `use super::*`.
    use tokio::sync::Mutex;

    /// Read-only secrets backend that only knows the webchat JWT signing key.
    ///
    /// `get`/`exists` recognise any URI whose category is `webchat` and name
    /// is `jwt_signing_key`. Writes and deletes always fail.
    #[derive(Clone)]
    struct TestSecretsBackend {
        secret: String,
    }

    impl TestSecretsBackend {
        fn new(secret: String) -> Self {
            Self { secret }
        }
    }

    impl SecretsBackend for TestSecretsBackend {
        fn put(
            &self,
            _record: secrets_core::SecretRecord,
        ) -> secrets_core::Result<secrets_core::SecretVersion> {
            Err(secrets_core::Error::Backend(
                "test backend is read-only".into(),
            ))
        }

        /// Returns the signing key as version 1, or `None` for any other URI.
        fn get(
            &self,
            uri: &SecretUri,
            _version: Option<u64>,
        ) -> secrets_core::Result<Option<VersionedSecret>> {
            if uri.category() == "webchat" && uri.name() == "jwt_signing_key" {
                let record = record_from_plain(self.secret.clone());
                Ok(Some(VersionedSecret {
                    version: 1,
                    deleted: false,
                    record: Some(record),
                }))
            } else {
                Ok(None)
            }
        }

        /// Always lists the single signing-key entry, ignoring scope and prefixes.
        fn list(
            &self,
            _scope: &secrets_core::Scope,
            _category_prefix: Option<&str>,
            _name_prefix: Option<&str>,
        ) -> secrets_core::Result<Vec<secrets_core::SecretListItem>> {
            let uri = SecretUri::new(
                Scope::new("global", "webchat", None).expect("scope"),
                "webchat".to_string(),
                "jwt_signing_key".to_string(),
            )
            .expect("uri");
            Ok(vec![secrets_core::SecretListItem {
                uri,
                visibility: secrets_core::Visibility::Tenant,
                latest_version: Some("1".into()),
                content_type: secrets_core::ContentType::Opaque,
            }])
        }

        fn delete(&self, _uri: &SecretUri) -> secrets_core::Result<secrets_core::SecretVersion> {
            Err(secrets_core::Error::Backend(
                "test backend is read-only".into(),
            ))
        }

        /// Reports a single live version for every URI.
        fn versions(
            &self,
            _uri: &SecretUri,
        ) -> secrets_core::Result<Vec<secrets_core::SecretVersion>> {
            Ok(vec![secrets_core::SecretVersion {
                version: 1,
                deleted: false,
            }])
        }

        fn exists(&self, uri: &SecretUri) -> secrets_core::Result<bool> {
            Ok(uri.category() == "webchat" && uri.name() == "jwt_signing_key")
        }
    }

    /// Builds a `WebChatProvider` for `base_url` that reads the signing key
    /// `"test-signing-key"` from [`TestSecretsBackend`] under the
    /// `global`/`webchat` scope.
    fn test_provider(base_url: &str) -> WebChatProvider {
        let backend = Arc::new(TestSecretsBackend::new("test-signing-key".to_string()));
        let scope = Scope::new("global", "webchat", None).expect("valid signing scope");
        WebChatProvider::new(Config::with_base_url(base_url), backend).with_signing_scope(scope)
    }

    /// A user `message` activity is appended to the conversation stream and
    /// published to the bus as exactly one `IncomingMessage` with a text payload.
    #[tokio::test]
    async fn user_activity_is_normalized_and_streamed() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(
                provider,
                conversations.clone(),
                sessions.clone(),
                bus.clone(),
            )
            .await
            .expect("state"),
        );

        let user_token = issue_token(&state, "user-1").await;
        let (conversation_id, conversation_token) = start_conversation(&state, &user_token).await;

        // Subscribe before posting so the streamed activity is not missed.
        let mut subscriber = conversations
            .subscribe(&conversation_id)
            .await
            .expect("subscribe");

        post_user_activity(
            &state,
            &conversation_id,
            &conversation_token,
            json!({
                "type": "message",
                "text": "hello"
            }),
        )
        .await;

        // NOTE(review): `recv()` has no timeout; if the activity is never
        // streamed this test hangs rather than fails.
        let stored = subscriber.recv().await.expect("activity");
        assert_eq!(stored.activity.text.as_deref(), Some("hello"));

        let events = bus.take().await;
        assert_eq!(events.len(), 1);
        match &events[0] {
            GreenticEvent::IncomingMessage(msg) => match &msg.payload {
                MessagePayload::Text { text, .. } => assert_eq!(text, "hello"),
                other => panic!("unexpected payload: {other:?}"),
            },
        }
    }

    /// An `invoke` activity is published as an `Event` payload whose name is
    /// the invoke's `name` field.
    #[tokio::test]
    async fn invoke_is_normalized_to_event() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(provider, conversations.clone(), sessions, bus.clone())
                .await
                .expect("state"),
        );
        let user_token = issue_token(&state, "user-2").await;
        let (conversation_id, conversation_token) = start_conversation(&state, &user_token).await;

        post_user_activity(
            &state,
            &conversation_id,
            &conversation_token,
            json!({
                "type": "invoke",
                "name": "adaptiveCard/action",
                "value": {"foo": "bar"}
            }),
        )
        .await;

        let events = bus.take().await;
        assert_eq!(events.len(), 1);
        match &events[0] {
            GreenticEvent::IncomingMessage(msg) => match &msg.payload {
                MessagePayload::Event { name, .. } => assert_eq!(name, "adaptiveCard/action"),
                other => panic!("expected event payload, got {other:?}"),
            },
        }
    }

    /// An activity posted through the admin endpoint is streamed with a `bot`
    /// role and advances the session watermark to "1". It is not republished
    /// to the bus.
    #[tokio::test]
    async fn admin_bot_activity_appends_and_streams() {
        let bus = Arc::new(RecordingBus::default());
        let sessions = Arc::new(MemorySessionStore::default());
        let conversations = memory_store();
        let provider = test_provider("http://localhost");
        let state = Arc::new(
            StandaloneState::with_store(
                provider,
                conversations.clone(),
                sessions.clone(),
                bus.clone(),
            )
            .await
            .expect("state"),
        );

        let user_token = issue_token(&state, "user-3").await;
        let (conversation_id, _) = start_conversation(&state, &user_token).await;
        let mut subscriber = conversations
            .subscribe(&conversation_id)
            .await
            .expect("subscribe");

        let Json(response) = admin_post_activity_handler(
            Extension(Arc::clone(&state)),
            Path(AdminPath {
                env: "dev".to_string(),
                tenant: "acme".to_string(),
            }),
            Json(AdminPostActivityRequest {
                team: None,
                conversation_id: Some(conversation_id.clone()),
                activity: json!({
                    "type": "message",
                    "text": "bot says hi"
                }),
            }),
        )
        .await
        .expect("admin post");

        assert_eq!(response.posted, 1);
        assert_eq!(response.skipped, 0);

        let stored = subscriber.recv().await.expect("activity");
        assert_eq!(stored.activity.text.as_deref(), Some("bot says hi"));
        // The sender role must be "bot" (case-insensitive); a missing role fails.
        assert!(
            stored
                .activity
                .from
                .as_ref()
                .and_then(|from| from.role.as_deref())
                .map(|role| role.eq_ignore_ascii_case("bot"))
                .unwrap_or(false)
        );

        let session = sessions
            .get(&conversation_id)
            .await
            .expect("session fetch")
            .expect("session exists");
        assert_eq!(session.watermark.as_deref(), Some("1"));

        // Bot-originated activity must not be echoed back onto the bus.
        assert!(bus.take().await.is_empty());
    }

    /// Mints a user token for tenant `acme` in env `dev` by calling
    /// `generate_token_handler` directly.
    ///
    /// `RemoteIp(None)` stands in for a request without connection info.
    /// NOTE(review): presumably this means no per-IP rate limit applies;
    /// confirm against `generate_token_handler`.
    async fn issue_token(state: &Arc<StandaloneState>, user_id: &str) -> String {
        let query = TenantQuery {
            env: "dev".to_string(),
            tenant: "acme".to_string(),
            team: None,
        };
        let body = GenerateTokenRequest {
            user: Some(UserDescriptor {
                id: Some(user_id.to_string()),
            }),
            trusted_origins: None,
        };
        let Json(response) = generate_token_handler(
            Extension(Arc::clone(state)),
            RemoteIp(None),
            Query(query),
            Json(body),
        )
        .await
        .expect("generate token");
        response.token
    }

    /// Starts a conversation with `token` as the bearer credential. Returns
    /// the new conversation id and the token carried in the handler's response.
    async fn start_conversation(state: &Arc<StandaloneState>, token: &str) -> (String, String) {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );
        let Json(response) = start_conversation_handler(Extension(Arc::clone(state)), headers)
            .await
            .expect("start conversation");
        (response.conversation_id, response.token)
    }

    /// Posts `body` as a JSON activity to `conversation_id` using `token` as
    /// the bearer credential, and asserts the handler answers `201 Created`.
    async fn post_user_activity(
        state: &Arc<StandaloneState>,
        conversation_id: &str,
        token: &str,
        body: serde_json::Value,
    ) {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
        );
        headers.insert(
            http::header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
        let response = post_activity_handler(
            Extension(Arc::clone(state)),
            Path(ConversationPath {
                id: conversation_id.to_string(),
            }),
            headers,
            Json(body),
        )
        .await
        .expect("post activity");
        let response = response.into_response();
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    /// Event bus that stores every published event in memory (subject ignored)
    /// so tests can inspect what the handlers emitted.
    #[derive(Default)]
    struct RecordingBus {
        events: Mutex<Vec<GreenticEvent>>,
    }

    impl RecordingBus {
        /// Drains and returns all events recorded so far.
        async fn take(&self) -> Vec<GreenticEvent> {
            let mut guard = self.events.lock().await;
            std::mem::take(&mut *guard)
        }
    }

    #[async_trait::async_trait]
    impl EventBus for RecordingBus {
        async fn publish(&self, _subject: &Subject, event: &GreenticEvent) -> anyhow::Result<()> {
            self.events.lock().await.push(event.clone());
            Ok(())
        }
    }
}