Skip to main content

daemon/grpc_local_impl/
signal.rs

1// SPDX-License-Identifier: Apache-2.0
2//! Local-mode `SignalService`. Reads risk signals attached to states by W1
3//! (`State.risk_signals`) and reports per-signal fire rates over a rolling
4//! window. The actual signal *computation* lands in R3 (`crates/state_review`);
5//! this service exposes whatever's already on disk.
6
7use std::{collections::HashMap, pin::Pin};
8
9use futures::Stream;
10use grpc::heddle::v1::{
11    ComputeStateSignalsRequest, ComputeStateSignalsResponse, GetRepoSignalHealthRequest,
12    PathSymbolRef, RepoSignalHealthReport, RiskSignal as ProtoRiskSignal,
13    SignalAnchor as ProtoSignalAnchor, SignalHealthEntry, SignalUpdate,
14    SubscribeSignalUpdatesRequest, signal_service_server::SignalService,
15};
16use objects::{
17    object::{ChangeId, RiskSignal, RiskSignalBlob, State},
18    store::ObjectStore,
19};
20use repo::Repository;
21use tokio_stream::wrappers::ReceiverStream;
22use tonic::{Request, Response, Status};
23
24use super::{GrpcLocalService, to_status};
25
/// Local-mode implementation of the gRPC `SignalService`.
///
/// Thin wrapper around a [`GrpcLocalService`]; every handler reaches the
/// repository through `inner.repo()`.
#[derive(Clone)]
pub struct LocalSignalService {
    /// Wrapped local gRPC service that exposes the repository handle.
    inner: GrpcLocalService,
}
30
31impl LocalSignalService {
32    pub fn new(inner: GrpcLocalService) -> Self {
33        Self { inner }
34    }
35
36    fn repo(&self) -> &Repository {
37        self.inner.repo()
38    }
39}
40
41#[tonic::async_trait]
42impl SignalService for LocalSignalService {
43    type SubscribeSignalUpdatesStream =
44        Pin<Box<dyn Stream<Item = Result<SignalUpdate, Status>> + Send>>;
45
46    async fn compute_state_signals(
47        &self,
48        request: Request<ComputeStateSignalsRequest>,
49    ) -> Result<Response<ComputeStateSignalsResponse>, Status> {
50        let req = request.into_inner();
51        if req.state_id.is_empty() {
52            return Err(Status::invalid_argument("state_id is required"));
53        }
54        let change_id = ChangeId::try_from_slice(&req.state_id)
55            .map_err(|err| Status::invalid_argument(format!("invalid state_id: {err}")))?;
56        let state = self
57            .repo()
58            .store()
59            .get_state(&change_id)
60            .map_err(to_status)?
61            .ok_or_else(|| Status::not_found(format!("state {change_id} not found")))?;
62        // Real R3 computation lives in `crates/state_review`. For W2 this
63        // service surfaces whatever signals already live on the state via
64        // `state.risk_signals`. When R3 wires through, replace this with
65        // a call into `state_review::registry::run_all`.
66        let signals = load_signals(self.repo(), &state)?;
67        let proto_signals = signals
68            .iter()
69            .map(|s| signal_to_proto(s, "visible"))
70            .collect();
71        Ok(Response::new(ComputeStateSignalsResponse {
72            signals: proto_signals,
73            tick_budget: 3,
74        }))
75    }
76
77    async fn get_repo_signal_health(
78        &self,
79        request: Request<GetRepoSignalHealthRequest>,
80    ) -> Result<Response<RepoSignalHealthReport>, Status> {
81        let req = request.into_inner();
82        let window = if req.window_states == 0 {
83            DEFAULT_HEALTH_WINDOW
84        } else {
85            req.window_states.min(MAX_HEALTH_WINDOW) as usize
86        };
87        // Walk recent states from HEAD's first-parent chain, up to `window`.
88        // For each, collect any `RiskSignal`s on disk. Aggregate per
89        // module_id; fire_rate = states-with-signals / states-considered.
90        let states = walk_recent_states(self.repo(), window).map_err(to_status)?;
91        let visited = states.len() as u32;
92        let mut per_module: HashMap<String, u32> = HashMap::new();
93        for state in &states {
94            let signals = load_signals(self.repo(), state)?;
95            // One state contributes at most once per module so a noisy
96            // module isn't double-counted by firing many signals on the
97            // same state.
98            let mut seen_modules = std::collections::HashSet::new();
99            for sig in &signals {
100                let key = sig.producer.module.clone();
101                if seen_modules.insert(key.clone()) {
102                    *per_module.entry(key).or_insert(0) += 1;
103                }
104            }
105        }
106        let warn_threshold = 0.5_f32;
107        let entries = per_module
108            .into_iter()
109            .map(|(module_id, hit_count)| {
110                let fire_rate = if visited == 0 {
111                    0.0
112                } else {
113                    hit_count as f32 / visited as f32
114                };
115                SignalHealthEntry {
116                    module_id,
117                    fire_rate,
118                    warn: fire_rate > warn_threshold,
119                }
120            })
121            .collect();
122        Ok(Response::new(RepoSignalHealthReport {
123            entries,
124            window_states: visited,
125        }))
126    }
127
128    async fn subscribe_signal_updates(
129        &self,
130        _request: Request<SubscribeSignalUpdatesRequest>,
131    ) -> Result<Response<Self::SubscribeSignalUpdatesStream>, Status> {
132        // W2 lands the contract; live event broadcasting wires up in R3 once
133        // the signal registry can recompute on capture. For now we open a
134        // channel that closes immediately — clients see EOF and reconnect
135        // when the producer becomes available.
136        let (_tx, rx) = tokio::sync::mpsc::channel::<Result<SignalUpdate, Status>>(1);
137        Ok(Response::new(Box::pin(ReceiverStream::new(rx))))
138    }
139}
140
/// Number of states walked by `get_repo_signal_health` when the request
/// leaves `window_states` at 0.
const DEFAULT_HEALTH_WINDOW: usize = 200;
/// Upper bound applied to a non-zero, caller-supplied `window_states`.
const MAX_HEALTH_WINDOW: u32 = 5_000;
143
144fn load_signals(repo: &Repository, state: &State) -> Result<Vec<RiskSignal>, Status> {
145    let Some(hash) = state.risk_signals else {
146        return Ok(Vec::new());
147    };
148    let blob = repo
149        .store()
150        .get_blob(&hash)
151        .map_err(to_status)?
152        .ok_or_else(|| {
153            Status::data_loss(format!(
154                "risk_signals blob {hash} referenced by state {} is missing",
155                state.change_id
156            ))
157        })?;
158    let parsed = RiskSignalBlob::decode(blob.content()).map_err(|err| {
159        Status::internal(format!(
160            "failed to decode risk signals on state {}: {err}",
161            state.change_id
162        ))
163    })?;
164    Ok(parsed.signals)
165}
166
167fn walk_recent_states(repo: &Repository, window: usize) -> objects::error::Result<Vec<State>> {
168    let mut out = Vec::new();
169    let mut cursor = repo.head()?;
170    while let Some(id) = cursor {
171        if out.len() >= window {
172            break;
173        }
174        let Some(state) = repo.store().get_state(&id)? else {
175            break;
176        };
177        let parent = state.parents.first().copied();
178        out.push(state);
179        cursor = parent;
180    }
181    Ok(out)
182}
183
184fn signal_to_proto(sig: &RiskSignal, visibility: &str) -> ProtoRiskSignal {
185    let (start_line, end_line) = sig.anchor.line_range.unwrap_or((0, 0));
186    ProtoRiskSignal {
187        kind: sig.kind.as_str().to_string(),
188        anchor: Some(ProtoSignalAnchor {
189            file: sig.anchor.file.clone(),
190            symbol: sig.anchor.symbol.clone().unwrap_or_default(),
191            start_line,
192            end_line,
193        }),
194        reason: sig.reason.clone(),
195        producer_module: sig.producer.module.clone(),
196        producer_version: sig.producer.version,
197        computed_at: Some(prost_types::Timestamp {
198            seconds: sig.computed_at,
199            nanos: 0,
200        }),
201        visibility: visibility.to_string(),
202    }
203}
204
205// Small helper kept private; exported via PathSymbolRef wherever needed.
206#[allow(dead_code)]
207fn make_path_symbol(file: &str, symbol: &str) -> PathSymbolRef {
208    PathSymbolRef {
209        file: file.to_string(),
210        symbol: symbol.to_string(),
211    }
212}
213
// Tests for the local signal service, run against a throwaway repository
// created in a temporary directory.
#[cfg(test)]
mod tests {
    use objects::object::{Attribution, Blob, Principal, ProducerId, RiskSignalKind, SignalAnchor};
    use tempfile::TempDir;

    use super::*;

    // Creates a repository in a new temp dir. The `TempDir` is returned so the
    // caller keeps the directory alive for the duration of the test.
    fn fresh_repo() -> (TempDir, Repository) {
        let temp = TempDir::new().unwrap();
        let repo = Repository::init_default(temp.path()).unwrap();
        (temp, repo)
    }

    // Wraps `repo` in a `LocalSignalService`, opening the operation dedup
    // store under the repository's heddle directory.
    fn local_service(repo: Repository) -> LocalSignalService {
        let dedup = std::sync::Arc::new(
            repo::operation_dedup::OperationDedupStore::open(repo.heddle_dir()).unwrap(),
        );
        LocalSignalService::new(GrpcLocalService::new(std::sync::Arc::new(repo), dedup))
    }

    // Takes a snapshot, stores `signals` as a blob, and rewrites the snapshot's
    // state so it references that blob. Returns the snapshot's change id.
    fn snapshot_with_signals(repo: &Repository, signals: Vec<RiskSignal>) -> ChangeId {
        let attribution = Attribution::human(Principal::new("Alice", "alice@example.com"));
        let snapshot = repo
            .snapshot_with_attribution(Some("test snapshot".to_string()), None, attribution)
            .unwrap();
        let blob = RiskSignalBlob::new(signals).encode().unwrap();
        let hash = repo.store().put_blob(&Blob::new(blob)).unwrap();
        let state = repo
            .store()
            .get_state(&snapshot.change_id)
            .unwrap()
            .unwrap();
        let updated = state.with_risk_signals(hash);
        repo.store().put_state(&updated).unwrap();
        snapshot.change_id
    }

    // Builds a signal of the given kind and reason with a fixed anchor,
    // producer (`novelty.tree_sitter`, version 1) and timestamp.
    fn sample_signal(kind: RiskSignalKind, reason: &str) -> RiskSignal {
        RiskSignal {
            kind,
            anchor: SignalAnchor::symbol("src/lib.rs", "foo"),
            reason: reason.to_string(),
            producer: ProducerId::new("novelty.tree_sitter", 1),
            computed_at: 1_700_000_000,
            computed_against: None,
        }
    }

    // A state with one persisted signal yields exactly that signal, marked
    // "visible".
    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn compute_state_signals_returns_persisted_signals() {
        let (_t, repo) = fresh_repo();
        let signal = sample_signal(RiskSignalKind::Novelty, "novel control flow shape");
        let state_id = snapshot_with_signals(&repo, vec![signal]);
        let svc = local_service(repo);
        let resp = svc
            .compute_state_signals(Request::new(ComputeStateSignalsRequest {
                repo_path: String::new(),
                state_id: state_id.as_bytes().to_vec(),
                prior_state_id: Vec::new(),
            }))
            .await
            .unwrap();
        let signals = resp.into_inner().signals;
        assert_eq!(signals.len(), 1);
        assert_eq!(signals[0].kind, "novelty");
        assert_eq!(signals[0].reason, "novel control flow shape");
        assert_eq!(signals[0].visibility, "visible");
    }

    // A snapshot that never had signals attached yields an empty list rather
    // than an error.
    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn compute_state_signals_returns_empty_when_state_has_no_signals() {
        let (_t, repo) = fresh_repo();
        let attribution = Attribution::human(Principal::new("Alice", "alice@example.com"));
        let snap = repo
            .snapshot_with_attribution(Some("plain".to_string()), None, attribution)
            .unwrap();
        let svc = local_service(repo);
        let resp = svc
            .compute_state_signals(Request::new(ComputeStateSignalsRequest {
                repo_path: String::new(),
                state_id: snap.change_id.as_bytes().to_vec(),
                prior_state_id: Vec::new(),
            }))
            .await
            .unwrap();
        assert!(resp.into_inner().signals.is_empty());
    }

    // Bytes that do not parse as a change id are rejected with
    // `InvalidArgument`.
    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn invalid_state_id_returns_invalid_argument() {
        let (_t, repo) = fresh_repo();
        let svc = local_service(repo);
        let err = svc
            .compute_state_signals(Request::new(ComputeStateSignalsRequest {
                repo_path: String::new(),
                state_id: "not-a-change-id".into(),
                prior_state_id: Vec::new(),
            }))
            .await
            .unwrap_err();
        assert_eq!(err.code(), tonic::Code::InvalidArgument);
    }

    // The health report contains an entry keyed by the producer module, with a
    // fire rate in (0, 1].
    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn signal_health_groups_by_module_id() {
        let (_t, repo) = fresh_repo();
        let novelty = sample_signal(RiskSignalKind::Novelty, "novel");
        snapshot_with_signals(&repo, vec![novelty]);
        let svc = local_service(repo);
        let resp = svc
            .get_repo_signal_health(Request::new(GetRepoSignalHealthRequest {
                repo_path: String::new(),
                window_states: 50,
            }))
            .await
            .unwrap();
        let report = resp.into_inner();
        assert!(report.window_states >= 1);
        let entry = report
            .entries
            .iter()
            .find(|e| e.module_id == "novelty.tree_sitter")
            .expect("novelty entry present");
        assert!(entry.fire_rate > 0.0 && entry.fire_rate <= 1.0);
    }
}
343}