1use std::{collections::HashMap, pin::Pin};
8
9use futures::Stream;
10use grpc::heddle::v1::{
11 ComputeStateSignalsRequest, ComputeStateSignalsResponse, GetRepoSignalHealthRequest,
12 PathSymbolRef, RepoSignalHealthReport, RiskSignal as ProtoRiskSignal,
13 SignalAnchor as ProtoSignalAnchor, SignalHealthEntry, SignalUpdate,
14 SubscribeSignalUpdatesRequest, signal_service_server::SignalService,
15};
16use objects::{
17 object::{ChangeId, RiskSignal, RiskSignalBlob, State},
18 store::ObjectStore,
19};
20use repo::Repository;
21use tokio_stream::wrappers::ReceiverStream;
22use tonic::{Request, Response, Status};
23
24use super::{GrpcLocalService, to_status};
25
26#[derive(Clone)]
27pub struct LocalSignalService {
28 inner: GrpcLocalService,
29}
30
31impl LocalSignalService {
32 pub fn new(inner: GrpcLocalService) -> Self {
33 Self { inner }
34 }
35
36 fn repo(&self) -> &Repository {
37 self.inner.repo()
38 }
39}
40
41#[tonic::async_trait]
42impl SignalService for LocalSignalService {
43 type SubscribeSignalUpdatesStream =
44 Pin<Box<dyn Stream<Item = Result<SignalUpdate, Status>> + Send>>;
45
46 async fn compute_state_signals(
47 &self,
48 request: Request<ComputeStateSignalsRequest>,
49 ) -> Result<Response<ComputeStateSignalsResponse>, Status> {
50 let req = request.into_inner();
51 if req.state_id.is_empty() {
52 return Err(Status::invalid_argument("state_id is required"));
53 }
54 let change_id = ChangeId::try_from_slice(&req.state_id)
55 .map_err(|err| Status::invalid_argument(format!("invalid state_id: {err}")))?;
56 let state = self
57 .repo()
58 .store()
59 .get_state(&change_id)
60 .map_err(to_status)?
61 .ok_or_else(|| Status::not_found(format!("state {change_id} not found")))?;
62 let signals = load_signals(self.repo(), &state)?;
67 let proto_signals = signals
68 .iter()
69 .map(|s| signal_to_proto(s, "visible"))
70 .collect();
71 Ok(Response::new(ComputeStateSignalsResponse {
72 signals: proto_signals,
73 tick_budget: 3,
74 }))
75 }
76
77 async fn get_repo_signal_health(
78 &self,
79 request: Request<GetRepoSignalHealthRequest>,
80 ) -> Result<Response<RepoSignalHealthReport>, Status> {
81 let req = request.into_inner();
82 let window = if req.window_states == 0 {
83 DEFAULT_HEALTH_WINDOW
84 } else {
85 req.window_states.min(MAX_HEALTH_WINDOW) as usize
86 };
87 let states = walk_recent_states(self.repo(), window).map_err(to_status)?;
91 let visited = states.len() as u32;
92 let mut per_module: HashMap<String, u32> = HashMap::new();
93 for state in &states {
94 let signals = load_signals(self.repo(), state)?;
95 let mut seen_modules = std::collections::HashSet::new();
99 for sig in &signals {
100 let key = sig.producer.module.clone();
101 if seen_modules.insert(key.clone()) {
102 *per_module.entry(key).or_insert(0) += 1;
103 }
104 }
105 }
106 let warn_threshold = 0.5_f32;
107 let entries = per_module
108 .into_iter()
109 .map(|(module_id, hit_count)| {
110 let fire_rate = if visited == 0 {
111 0.0
112 } else {
113 hit_count as f32 / visited as f32
114 };
115 SignalHealthEntry {
116 module_id,
117 fire_rate,
118 warn: fire_rate > warn_threshold,
119 }
120 })
121 .collect();
122 Ok(Response::new(RepoSignalHealthReport {
123 entries,
124 window_states: visited,
125 }))
126 }
127
128 async fn subscribe_signal_updates(
129 &self,
130 _request: Request<SubscribeSignalUpdatesRequest>,
131 ) -> Result<Response<Self::SubscribeSignalUpdatesStream>, Status> {
132 let (_tx, rx) = tokio::sync::mpsc::channel::<Result<SignalUpdate, Status>>(1);
137 Ok(Response::new(Box::pin(ReceiverStream::new(rx))))
138 }
139}
140
/// Number of states walked by the signal health RPC when the request's `window_states` is 0.
const DEFAULT_HEALTH_WINDOW: usize = 200;
/// Upper bound on a requested `window_states`; larger requests are clamped to this value.
const MAX_HEALTH_WINDOW: u32 = 5_000;

// An unset window must never walk more states than the largest explicit request allows.
const _: () = assert!(DEFAULT_HEALTH_WINDOW <= MAX_HEALTH_WINDOW as usize);
143
144fn load_signals(repo: &Repository, state: &State) -> Result<Vec<RiskSignal>, Status> {
145 let Some(hash) = state.risk_signals else {
146 return Ok(Vec::new());
147 };
148 let blob = repo
149 .store()
150 .get_blob(&hash)
151 .map_err(to_status)?
152 .ok_or_else(|| {
153 Status::data_loss(format!(
154 "risk_signals blob {hash} referenced by state {} is missing",
155 state.change_id
156 ))
157 })?;
158 let parsed = RiskSignalBlob::decode(blob.content()).map_err(|err| {
159 Status::internal(format!(
160 "failed to decode risk signals on state {}: {err}",
161 state.change_id
162 ))
163 })?;
164 Ok(parsed.signals)
165}
166
167fn walk_recent_states(repo: &Repository, window: usize) -> objects::error::Result<Vec<State>> {
168 let mut out = Vec::new();
169 let mut cursor = repo.head()?;
170 while let Some(id) = cursor {
171 if out.len() >= window {
172 break;
173 }
174 let Some(state) = repo.store().get_state(&id)? else {
175 break;
176 };
177 let parent = state.parents.first().copied();
178 out.push(state);
179 cursor = parent;
180 }
181 Ok(out)
182}
183
184fn signal_to_proto(sig: &RiskSignal, visibility: &str) -> ProtoRiskSignal {
185 let (start_line, end_line) = sig.anchor.line_range.unwrap_or((0, 0));
186 ProtoRiskSignal {
187 kind: sig.kind.as_str().to_string(),
188 anchor: Some(ProtoSignalAnchor {
189 file: sig.anchor.file.clone(),
190 symbol: sig.anchor.symbol.clone().unwrap_or_default(),
191 start_line,
192 end_line,
193 }),
194 reason: sig.reason.clone(),
195 producer_module: sig.producer.module.clone(),
196 producer_version: sig.producer.version,
197 computed_at: Some(prost_types::Timestamp {
198 seconds: sig.computed_at,
199 nanos: 0,
200 }),
201 visibility: visibility.to_string(),
202 }
203}
204
205#[allow(dead_code)]
207fn make_path_symbol(file: &str, symbol: &str) -> PathSymbolRef {
208 PathSymbolRef {
209 file: file.to_string(),
210 symbol: symbol.to_string(),
211 }
212}
213
#[cfg(test)]
mod tests {
    use objects::object::{Attribution, Blob, Principal, ProducerId, RiskSignalKind, SignalAnchor};
    use tempfile::TempDir;

    use super::*;

    /// Initializes a repository in a fresh temp dir.
    ///
    /// The `TempDir` must be kept alive for as long as the repository is used.
    fn fresh_repo() -> (TempDir, Repository) {
        let temp = TempDir::new().unwrap();
        let repo = Repository::init_default(temp.path()).unwrap();
        (temp, repo)
    }

    /// Wraps `repo` in a `LocalSignalService` backed by an operation-dedup store opened in
    /// the repository's heddle dir.
    fn local_service(repo: Repository) -> LocalSignalService {
        let dedup = std::sync::Arc::new(
            repo::operation_dedup::OperationDedupStore::open(repo.heddle_dir()).unwrap(),
        );
        LocalSignalService::new(GrpcLocalService::new(std::sync::Arc::new(repo), dedup))
    }

    /// Takes a snapshot, then rewrites its state to reference a blob holding `signals`.
    ///
    /// Returns the snapshot's change id.
    fn snapshot_with_signals(repo: &Repository, signals: Vec<RiskSignal>) -> ChangeId {
        let attribution = Attribution::human(Principal::new("Alice", "alice@example.com"));
        let snapshot = repo
            .snapshot_with_attribution(Some("test snapshot".to_string()), None, attribution)
            .unwrap();
        let blob = RiskSignalBlob::new(signals).encode().unwrap();
        let hash = repo.store().put_blob(&Blob::new(blob)).unwrap();
        let state = repo
            .store()
            .get_state(&snapshot.change_id)
            .unwrap()
            .unwrap();
        let updated = state.with_risk_signals(hash);
        repo.store().put_state(&updated).unwrap();
        snapshot.change_id
    }

    /// Builds a signal anchored at `src/lib.rs` / `foo`, produced by `novelty.tree_sitter` v1.
    fn sample_signal(kind: RiskSignalKind, reason: &str) -> RiskSignal {
        RiskSignal {
            kind,
            anchor: SignalAnchor::symbol("src/lib.rs", "foo"),
            reason: reason.to_string(),
            producer: ProducerId::new("novelty.tree_sitter", 1),
            computed_at: 1_700_000_000,
            computed_against: None,
        }
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn compute_state_signals_returns_persisted_signals() {
        let (_t, repo) = fresh_repo();
        let signal = sample_signal(RiskSignalKind::Novelty, "novel control flow shape");
        let state_id = snapshot_with_signals(&repo, vec![signal]);
        let svc = local_service(repo);
        let resp = svc
            .compute_state_signals(Request::new(ComputeStateSignalsRequest {
                repo_path: String::new(),
                state_id: state_id.as_bytes().to_vec(),
                prior_state_id: Vec::new(),
            }))
            .await
            .unwrap();
        let signals = resp.into_inner().signals;
        assert_eq!(signals.len(), 1);
        assert_eq!(signals[0].kind, "novelty");
        assert_eq!(signals[0].reason, "novel control flow shape");
        assert_eq!(signals[0].visibility, "visible");
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn compute_state_signals_returns_empty_when_state_has_no_signals() {
        let (_t, repo) = fresh_repo();
        let attribution = Attribution::human(Principal::new("Alice", "alice@example.com"));
        let snap = repo
            .snapshot_with_attribution(Some("plain".to_string()), None, attribution)
            .unwrap();
        let svc = local_service(repo);
        let resp = svc
            .compute_state_signals(Request::new(ComputeStateSignalsRequest {
                repo_path: String::new(),
                state_id: snap.change_id.as_bytes().to_vec(),
                prior_state_id: Vec::new(),
            }))
            .await
            .unwrap();
        assert!(resp.into_inner().signals.is_empty());
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn invalid_state_id_returns_invalid_argument() {
        let (_t, repo) = fresh_repo();
        let svc = local_service(repo);
        let err = svc
            .compute_state_signals(Request::new(ComputeStateSignalsRequest {
                repo_path: String::new(),
                state_id: "not-a-change-id".into(),
                prior_state_id: Vec::new(),
            }))
            .await
            .unwrap_err();
        assert_eq!(err.code(), tonic::Code::InvalidArgument);
    }

    /// Covers the explicit empty-`state_id` guard, which is distinct from the
    /// malformed-id path exercised above.
    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn empty_state_id_returns_invalid_argument() {
        let (_t, repo) = fresh_repo();
        let svc = local_service(repo);
        let err = svc
            .compute_state_signals(Request::new(ComputeStateSignalsRequest {
                repo_path: String::new(),
                state_id: Vec::new(),
                prior_state_id: Vec::new(),
            }))
            .await
            .unwrap_err();
        assert_eq!(err.code(), tonic::Code::InvalidArgument);
        assert_eq!(err.message(), "state_id is required");
    }

    #[tokio::test]
    #[serial_test::serial(process_global)]
    async fn signal_health_groups_by_module_id() {
        let (_t, repo) = fresh_repo();
        let novelty = sample_signal(RiskSignalKind::Novelty, "novel");
        snapshot_with_signals(&repo, vec![novelty]);
        let svc = local_service(repo);
        let resp = svc
            .get_repo_signal_health(Request::new(GetRepoSignalHealthRequest {
                repo_path: String::new(),
                window_states: 50,
            }))
            .await
            .unwrap();
        let report = resp.into_inner();
        assert!(report.window_states >= 1);
        let entry = report
            .entries
            .iter()
            .find(|e| e.module_id == "novelty.tree_sitter")
            .expect("novelty entry present");
        assert!(entry.fire_rate > 0.0 && entry.fire_rate <= 1.0);
    }
}