Skip to main content

chainsaw/
session.rs

1//! Session: owns a loaded dependency graph and exposes query methods.
2//!
3//! A [`Session`] is the primary interface for library consumers (CLI, REPL,
4//! language server). It wraps graph loading, entry resolution, and keeps the
5//! background cache-write handle alive for the duration of the session.
6
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11
12use notify::{RecommendedWatcher, RecursiveMode, Watcher};
13
14use crate::cache::{CacheWriteHandle, LOCKFILES};
15use crate::error::Error;
16use crate::graph::{EdgeId, EdgeKind, ModuleGraph, ModuleId, PackageInfo};
17use crate::loader;
18use crate::query::{self, ChainTarget, CutModule, DiffResult, TraceOptions, TraceResult};
19use crate::report::{
20    self, ChainReport, CutEntry, CutReport, DiffReport, ModuleEntry, PackageEntry,
21    PackageListEntry, PackagesReport, TraceReport,
22};
23
24/// The result of resolving a `--chain`/`--cut` argument against the graph.
25///
26/// The argument might be a file path (resolved to a [`ChainTarget::Module`])
27/// or a package name (resolved to a [`ChainTarget::Package`]).
28pub struct ResolvedTarget {
29    pub target: ChainTarget,
30    pub label: String,
31    pub exists: bool,
32}
33
34/// An open dependency-graph session.
35///
36/// Created via [`Session::open`], which loads (or builds) the graph and
37/// resolves the entry module. The background cache writer is joined on drop.
38pub struct Session {
39    graph: ModuleGraph,
40    reverse_adj: Vec<Vec<EdgeId>>,
41    root: PathBuf,
42    entry: PathBuf,
43    entry_id: ModuleId,
44    valid_extensions: &'static [&'static str],
45    from_cache: bool,
46    unresolvable_dynamic_count: usize,
47    unresolvable_dynamic_files: Vec<(PathBuf, usize)>,
48    file_warnings: Vec<String>,
49    _cache_handle: CacheWriteHandle,
50    dirty: Arc<AtomicBool>,
51    watcher: Option<RecommendedWatcher>,
52    cached_trace: Option<CachedTrace>,
53    cached_weights: Option<CachedWeights>,
54}
55
56/// Cached entry trace result, keyed on `(entry_id, include_dynamic)`.
57///
58/// The cache intentionally ignores `TraceOptions::top_n` and `ignore` because
59/// those fields only affect `heavy_packages` filtering — they don't change the
60/// underlying traversal (`static_weight`, `modules_by_cost`, `all_packages`).
61/// This is safe as long as callers use consistent options across cached calls
62/// (the REPL always uses `TraceOptions::default()`).
63struct CachedTrace {
64    entry_id: ModuleId,
65    include_dynamic: bool,
66    result: TraceResult,
67}
68
69struct CachedWeights {
70    entry_id: ModuleId,
71    include_dynamic: bool,
72    weights: Vec<u64>,
73}
74
75fn build_reverse_adj(graph: &ModuleGraph) -> Vec<Vec<EdgeId>> {
76    let mut rev = vec![Vec::new(); graph.module_count()];
77    for edge in &graph.edges {
78        rev[edge.to.0 as usize].push(edge.id);
79    }
80    rev
81}
82
83impl Session {
84    /// Load a dependency graph from `entry` and resolve the entry module.
85    ///
86    /// When `no_cache` is true the on-disk cache is bypassed entirely.
87    pub fn open(entry: &Path, no_cache: bool) -> Result<Self, Error> {
88        let (loaded, cache_handle) = loader::load_graph(entry, no_cache)?;
89
90        let entry_id = *loaded
91            .graph
92            .path_to_id
93            .get(&loaded.entry)
94            .ok_or_else(|| Error::EntryNotInGraph(loaded.entry.clone()))?;
95
96        let reverse_adj = build_reverse_adj(&loaded.graph);
97
98        Ok(Self {
99            graph: loaded.graph,
100            reverse_adj,
101            root: loaded.root,
102            entry: loaded.entry,
103            entry_id,
104            valid_extensions: loaded.valid_extensions,
105            from_cache: loaded.from_cache,
106            unresolvable_dynamic_count: loaded.unresolvable_dynamic_count,
107            unresolvable_dynamic_files: loaded.unresolvable_dynamic_files,
108            file_warnings: loaded.file_warnings,
109            _cache_handle: cache_handle,
110            dirty: Arc::new(AtomicBool::new(false)),
111            watcher: None,
112            cached_trace: None,
113            cached_weights: None,
114        })
115    }
116
117    /// Trace transitive import weight from the entry module.
118    pub fn trace(&self, opts: &TraceOptions) -> TraceResult {
119        query::trace(&self.graph, self.entry_id, opts)
120    }
121
122    /// Trace transitive import weight from an arbitrary file in the graph.
123    pub fn trace_from(
124        &self,
125        file: &Path,
126        opts: &TraceOptions,
127    ) -> Result<(TraceResult, PathBuf), Error> {
128        let canon = file
129            .canonicalize()
130            .or_else(|_| self.root.join(file).canonicalize())
131            .map_err(|e| Error::EntryNotFound(file.to_path_buf(), e))?;
132        let Some(&id) = self.graph.path_to_id.get(&canon) else {
133            return Err(Error::EntryNotInGraph(canon));
134        };
135        Ok((query::trace(&self.graph, id, opts), canon))
136    }
137
138    /// Resolve a chain/cut argument to a [`ChainTarget`].
139    ///
140    /// If the argument looks like a file path and resolves to a module in
141    /// the graph, returns `ChainTarget::Module`. Otherwise falls through
142    /// to a package name lookup.
143    pub fn resolve_target(&self, arg: &str) -> ResolvedTarget {
144        if looks_like_path(arg, self.valid_extensions)
145            && let Ok(target_path) = self.root.join(arg).canonicalize()
146            && let Some(&id) = self.graph.path_to_id.get(&target_path)
147        {
148            let p = &self.graph.module(id).path;
149            let label = p
150                .strip_prefix(&self.root)
151                .unwrap_or(p)
152                .to_string_lossy()
153                .into_owned();
154            return ResolvedTarget {
155                target: ChainTarget::Module(id),
156                label,
157                exists: true,
158            };
159        }
160        // File doesn't exist or isn't in the graph -- fall through to
161        // package name lookup. Handles packages like "six.py" or
162        // "highlight.js" whose names match file extensions.
163        let name = arg.to_string();
164        let exists = self.graph.package_map.contains_key(arg);
165        let label = name.clone();
166        ResolvedTarget {
167            target: ChainTarget::Package(name),
168            label,
169            exists,
170        }
171    }
172
173    /// Find all shortest import chains from the entry to a target.
174    pub fn chain(
175        &self,
176        target_arg: &str,
177        include_dynamic: bool,
178    ) -> (ResolvedTarget, Vec<Vec<ModuleId>>) {
179        let resolved = self.resolve_target(target_arg);
180        let chains = query::find_all_chains(
181            &self.graph,
182            self.entry_id,
183            &resolved.target,
184            include_dynamic,
185        );
186        (resolved, chains)
187    }
188
189    /// Find import chains and optimal cut points to sever them.
190    pub fn cut(
191        &mut self,
192        target_arg: &str,
193        top: i32,
194        include_dynamic: bool,
195    ) -> (ResolvedTarget, Vec<Vec<ModuleId>>, Vec<CutModule>) {
196        let resolved = self.resolve_target(target_arg);
197        let chains = query::find_all_chains(
198            &self.graph,
199            self.entry_id,
200            &resolved.target,
201            include_dynamic,
202        );
203        self.ensure_weights(include_dynamic);
204        let weights = &self
205            .cached_weights
206            .as_ref()
207            .expect("ensure_weights populates cache")
208            .weights;
209        let cuts = query::find_cut_modules(
210            &self.graph,
211            &chains,
212            self.entry_id,
213            &resolved.target,
214            top,
215            weights,
216        );
217        (resolved, chains, cuts)
218    }
219
220    /// Trace from a different entry point in the same graph and diff
221    /// against the current entry. Returns the diff and the canonical
222    /// path of the other entry (avoids redundant canonicalization by
223    /// the caller).
224    pub fn diff_entry(
225        &mut self,
226        other: &Path,
227        opts: &TraceOptions,
228    ) -> Result<(DiffResult, PathBuf), Error> {
229        let other_canon = other
230            .canonicalize()
231            .or_else(|_| self.root.join(other).canonicalize())
232            .map_err(|e| Error::EntryNotFound(other.to_path_buf(), e))?;
233        let Some(&other_id) = self.graph.path_to_id.get(&other_canon) else {
234            return Err(Error::EntryNotInGraph(other_canon.clone()));
235        };
236        self.ensure_trace(opts);
237        let snap_a = self
238            .cached_trace
239            .as_ref()
240            .expect("ensure_trace populates cache")
241            .result
242            .to_snapshot(&self.entry_label());
243        let snap_b = query::trace(&self.graph, other_id, opts)
244            .to_snapshot(&self.entry_label_for(&other_canon));
245        Ok((query::diff_snapshots(&snap_a, &snap_b), other_canon))
246    }
247
248    /// All third-party packages in the dependency graph.
249    pub fn packages(&self) -> &HashMap<String, PackageInfo> {
250        &self.graph.package_map
251    }
252
253    /// List direct imports of a file (outgoing edges).
254    pub fn imports(&self, file: &Path) -> Result<Vec<(PathBuf, EdgeKind)>, Error> {
255        let canon = file
256            .canonicalize()
257            .or_else(|_| self.root.join(file).canonicalize())
258            .map_err(|e| Error::EntryNotFound(file.to_path_buf(), e))?;
259        let Some(&id) = self.graph.path_to_id.get(&canon) else {
260            return Err(Error::EntryNotInGraph(canon));
261        };
262        let result = self
263            .graph
264            .outgoing_edges(id)
265            .iter()
266            .map(|&eid| {
267                let edge = self.graph.edge(eid);
268                (self.graph.module(edge.to).path.clone(), edge.kind)
269            })
270            .collect();
271        Ok(result)
272    }
273
274    /// List files that import a given file (reverse edge lookup).
275    pub fn importers(&self, file: &Path) -> Result<Vec<(PathBuf, EdgeKind)>, Error> {
276        let canon = file
277            .canonicalize()
278            .or_else(|_| self.root.join(file).canonicalize())
279            .map_err(|e| Error::EntryNotFound(file.to_path_buf(), e))?;
280        let Some(&id) = self.graph.path_to_id.get(&canon) else {
281            return Err(Error::EntryNotInGraph(canon));
282        };
283        let result = self.reverse_adj[id.0 as usize]
284            .iter()
285            .map(|&eid| {
286                let edge = self.graph.edge(eid);
287                (self.graph.module(edge.from).path.clone(), edge.kind)
288            })
289            .collect();
290        Ok(result)
291    }
292
293    /// Look up package info by name.
294    pub fn info(&self, package_name: &str) -> Option<&PackageInfo> {
295        self.graph.package_map.get(package_name)
296    }
297
298    /// Display label for the current entry point, including the project
299    /// directory name for disambiguation (e.g. `wrangler/src/index.ts`).
300    pub fn entry_label(&self) -> String {
301        self.entry_label_for(&self.entry)
302    }
303
304    /// Display label for an arbitrary path, relative to the project root.
305    pub fn entry_label_for(&self, path: &Path) -> String {
306        entry_label(path, &self.root)
307    }
308
309    /// Switch the default entry point to a different file in the graph.
310    ///
311    /// The file must already be in the graph (no rebuild). Accepts both
312    /// absolute paths and paths relative to the project root.
313    pub fn set_entry(&mut self, path: &Path) -> Result<(), Error> {
314        let canon = path
315            .canonicalize()
316            .or_else(|_| self.root.join(path).canonicalize())
317            .map_err(|e| Error::EntryNotFound(path.to_path_buf(), e))?;
318        let Some(&id) = self.graph.path_to_id.get(&canon) else {
319            return Err(Error::EntryNotInGraph(canon));
320        };
321        self.entry = canon;
322        self.entry_id = id;
323        self.invalidate_cache();
324        Ok(())
325    }
326
327    /// Start watching the project root for file changes.
328    ///
329    /// After calling this, `refresh()` will short-circuit when no relevant
330    /// files have changed since the last refresh. Idempotent: calling
331    /// `watch()` again replaces the existing watcher.
332    pub fn watch(&mut self) {
333        let dirty = Arc::clone(&self.dirty);
334        let extensions: Vec<String> = self
335            .valid_extensions
336            .iter()
337            .map(|&e| e.to_string())
338            .collect();
339
340        let handler = move |event: notify::Result<notify::Event>| {
341            if dirty.load(Ordering::Relaxed) {
342                return;
343            }
344            let Ok(event) = event else { return };
345            match event.kind {
346                notify::EventKind::Create(_)
347                | notify::EventKind::Modify(_)
348                | notify::EventKind::Remove(_) => {}
349                _ => return,
350            }
351            if event.paths.iter().any(|p| is_relevant_path(p, &extensions)) {
352                dirty.store(true, Ordering::Release);
353            }
354        };
355
356        if let Ok(mut watcher) = RecommendedWatcher::new(handler, notify::Config::default())
357            && watcher.watch(&self.root, RecursiveMode::Recursive).is_ok()
358        {
359            self.watcher = Some(watcher);
360        }
361    }
362
363    /// Whether the watcher has detected changes since the last refresh.
364    pub fn is_dirty(&self) -> bool {
365        self.dirty.load(Ordering::Acquire)
366    }
367
368    /// Check for file changes and rebuild the graph if needed.
369    ///
370    /// Returns `true` if the graph was updated (cold build or module count
371    /// changed since the last load).
372    #[allow(clippy::used_underscore_binding)] // _cache_handle held for drop
373    pub fn refresh(&mut self) -> Result<bool, Error> {
374        // Fast path: if a watcher is active and no relevant files changed,
375        // skip the full cache-hit path entirely.
376        if self.watcher.is_some() && !self.dirty.swap(false, Ordering::AcqRel) {
377            return Ok(false);
378        }
379
380        let (loaded, handle) = loader::load_graph(&self.entry, false)?;
381        let Some(&entry_id) = loaded.graph.path_to_id.get(&loaded.entry) else {
382            return Err(Error::EntryNotInGraph(loaded.entry));
383        };
384        // Detect structural change: cold build (not from cache) or module count
385        // changed. When from_cache is true and module count matches, edges are
386        // guaranteed identical (tier 1.5 only returns from_cache when imports
387        // are unchanged), so we can reuse the existing reverse adjacency index.
388        let changed =
389            !loaded.from_cache || loaded.graph.module_count() != self.graph.module_count();
390        if changed {
391            self.reverse_adj = build_reverse_adj(&loaded.graph);
392            self.invalidate_cache();
393        } else {
394            debug_assert_eq!(
395                self.reverse_adj,
396                build_reverse_adj(&loaded.graph),
397                "reverse_adj out of sync: cache reported unchanged but edges differ"
398            );
399        }
400        self.graph = loaded.graph;
401        self.root = loaded.root;
402        self.entry = loaded.entry;
403        self.entry_id = entry_id;
404        self.valid_extensions = loaded.valid_extensions;
405        self.from_cache = loaded.from_cache;
406        self.unresolvable_dynamic_count = loaded.unresolvable_dynamic_count;
407        self.unresolvable_dynamic_files = loaded.unresolvable_dynamic_files;
408        self.file_warnings = loaded.file_warnings;
409        self._cache_handle = handle;
410        Ok(changed)
411    }
412
413    // -- query cache --
414
415    fn invalidate_cache(&mut self) {
416        self.cached_trace = None;
417        self.cached_weights = None;
418    }
419
420    fn ensure_trace(&mut self, opts: &TraceOptions) {
421        let valid = self.cached_trace.as_ref().is_some_and(|c| {
422            c.entry_id == self.entry_id && c.include_dynamic == opts.include_dynamic
423        });
424        if !valid {
425            let result = query::trace(&self.graph, self.entry_id, opts);
426            self.cached_trace = Some(CachedTrace {
427                entry_id: self.entry_id,
428                include_dynamic: opts.include_dynamic,
429                result,
430            });
431        }
432    }
433
434    fn ensure_weights(&mut self, include_dynamic: bool) {
435        let valid = self
436            .cached_weights
437            .as_ref()
438            .is_some_and(|c| c.entry_id == self.entry_id && c.include_dynamic == include_dynamic);
439        if !valid {
440            let weights =
441                query::compute_exclusive_weights(&self.graph, self.entry_id, include_dynamic);
442            self.cached_weights = Some(CachedWeights {
443                entry_id: self.entry_id,
444                include_dynamic,
445                weights,
446            });
447        }
448    }
449
450    // -- report builders --
451
452    /// Trace and produce a display-ready report.
453    pub fn trace_report(&mut self, opts: &TraceOptions, top_modules: i32) -> TraceReport {
454        self.ensure_trace(opts);
455        let result = &self
456            .cached_trace
457            .as_ref()
458            .expect("ensure_trace populates cache")
459            .result;
460        build_trace_report(
461            result,
462            &self.entry,
463            &self.graph,
464            &self.root,
465            opts,
466            top_modules,
467        )
468    }
469
470    /// Trace from a different file and produce a display-ready report.
471    pub fn trace_from_report(
472        &self,
473        file: &Path,
474        opts: &TraceOptions,
475        top_modules: i32,
476    ) -> Result<(TraceReport, PathBuf), Error> {
477        let (result, canon) = self.trace_from(file, opts)?;
478        Ok((
479            build_trace_report(&result, &canon, &self.graph, &self.root, opts, top_modules),
480            canon,
481        ))
482    }
483
484    /// Find import chains and produce a display-ready report.
485    pub fn chain_report(&self, target_arg: &str, include_dynamic: bool) -> ChainReport {
486        let (resolved, chains) = self.chain(target_arg, include_dynamic);
487        ChainReport {
488            target: resolved.label,
489            found_in_graph: resolved.exists,
490            chain_count: chains.len(),
491            hop_count: chains.first().map_or(0, |c| c.len().saturating_sub(1)),
492            chains: chains
493                .iter()
494                .map(|chain| report::chain_display_names(&self.graph, chain, &self.root))
495                .collect(),
496        }
497    }
498
499    /// Find cut points and produce a display-ready report.
500    pub fn cut_report(&mut self, target_arg: &str, top: i32, include_dynamic: bool) -> CutReport {
501        let (resolved, chains, cuts) = self.cut(target_arg, top, include_dynamic);
502        CutReport {
503            target: resolved.label,
504            found_in_graph: resolved.exists,
505            chain_count: chains.len(),
506            direct_import: !chains.is_empty()
507                && cuts.is_empty()
508                && chains.iter().all(|c| c.len() == 2),
509            cut_points: cuts
510                .iter()
511                .map(|c| CutEntry {
512                    module: report::display_name(&self.graph, c.module_id, &self.root),
513                    exclusive_size_bytes: c.exclusive_size,
514                    chains_broken: c.chains_broken,
515                })
516                .collect(),
517        }
518    }
519
520    /// Diff two entry points and produce a display-ready report.
521    pub fn diff_report(
522        &mut self,
523        other: &Path,
524        opts: &TraceOptions,
525        limit: i32,
526    ) -> Result<DiffReport, Error> {
527        let (diff, other_canon) = self.diff_entry(other, opts)?;
528        let entry_a = self.entry_label();
529        let entry_b = self.entry_label_for(&other_canon);
530        Ok(DiffReport::from_diff(&diff, &entry_a, &entry_b, limit))
531    }
532
533    /// List packages and produce a display-ready report.
534    #[allow(clippy::cast_sign_loss)]
535    pub fn packages_report(&self, top: i32) -> PackagesReport {
536        let mut packages: Vec<_> = self.graph.package_map.values().collect();
537        packages.sort_by(|a, b| b.total_reachable_size.cmp(&a.total_reachable_size));
538        let total = packages.len();
539        let display_count = if top < 0 {
540            total
541        } else {
542            total.min(top as usize)
543        };
544
545        PackagesReport {
546            package_count: total,
547            packages: packages[..display_count]
548                .iter()
549                .map(|pkg| PackageListEntry {
550                    name: pkg.name.clone(),
551                    total_size_bytes: pkg.total_reachable_size,
552                    file_count: pkg.total_reachable_files,
553                })
554                .collect(),
555        }
556    }
557
558    // -- accessors --
559
560    pub fn graph(&self) -> &ModuleGraph {
561        &self.graph
562    }
563
564    pub fn root(&self) -> &Path {
565        &self.root
566    }
567
568    pub fn entry(&self) -> &Path {
569        &self.entry
570    }
571
572    pub fn entry_id(&self) -> ModuleId {
573        self.entry_id
574    }
575
576    pub fn valid_extensions(&self) -> &'static [&'static str] {
577        self.valid_extensions
578    }
579
580    pub fn from_cache(&self) -> bool {
581        self.from_cache
582    }
583
584    pub fn unresolvable_dynamic_count(&self) -> usize {
585        self.unresolvable_dynamic_count
586    }
587
588    pub fn unresolvable_dynamic_files(&self) -> &[(PathBuf, usize)] {
589        &self.unresolvable_dynamic_files
590    }
591
592    pub fn file_warnings(&self) -> &[String] {
593        &self.file_warnings
594    }
595}
596
/// Label an entry point with the project directory name prepended, so that
/// entries from different projects are distinguishable
/// (`wrangler/src/index.ts` rather than just `src/index.ts`).
///
/// If `path` is not under `root` it is used unchanged as the relative part.
/// If `root` has no final component (e.g. `/`), no directory name is
/// prepended.
pub fn entry_label(path: &Path, root: &Path) -> String {
    let relative = match path.strip_prefix(root) {
        Ok(inside) => inside,
        Err(_) => path,
    };
    let labelled = match root.file_name() {
        Some(project_dir) => Path::new(project_dir).join(relative),
        None => relative.to_path_buf(),
    };
    labelled.to_string_lossy().into_owned()
}
607
608#[allow(clippy::cast_sign_loss)]
609fn build_trace_report(
610    result: &TraceResult,
611    entry_path: &Path,
612    graph: &ModuleGraph,
613    root: &Path,
614    opts: &TraceOptions,
615    top_modules: i32,
616) -> TraceReport {
617    let heavy_packages = result
618        .heavy_packages
619        .iter()
620        .map(|pkg| PackageEntry {
621            name: pkg.name.clone(),
622            total_size_bytes: pkg.total_size,
623            file_count: pkg.file_count,
624            chain: report::chain_display_names(graph, &pkg.chain, root),
625        })
626        .collect();
627
628    let display_count = if top_modules < 0 {
629        result.modules_by_cost.len()
630    } else {
631        result.modules_by_cost.len().min(top_modules as usize)
632    };
633    let modules_by_cost = result.modules_by_cost[..display_count]
634        .iter()
635        .map(|mc| ModuleEntry {
636            path: report::relative_path(&graph.module(mc.module_id).path, root),
637            exclusive_size_bytes: mc.exclusive_size,
638        })
639        .collect();
640
641    TraceReport {
642        entry: report::relative_path(entry_path, root),
643        static_weight_bytes: result.static_weight,
644        static_module_count: result.static_module_count,
645        dynamic_only_weight_bytes: result.dynamic_only_weight,
646        dynamic_only_module_count: result.dynamic_only_module_count,
647        heavy_packages,
648        modules_by_cost,
649        total_modules_with_cost: result.modules_by_cost.len(),
650        include_dynamic: opts.include_dynamic,
651        top: opts.top_n,
652    }
653}
654
/// Heuristically decide whether a chain/cut argument names a file path
/// rather than a package.
///
/// Scoped npm names (`@scope/pkg`) are never paths. Anything else containing
/// a path separator is a path, as is a name whose final `.`-suffix is one of
/// `extensions`. The extension rule is ambiguous for packages like
/// `highlight.js`; callers fall back to a package lookup when needed.
pub fn looks_like_path(arg: &str, extensions: &[&str]) -> bool {
    if arg.starts_with('@') {
        return false;
    }
    let has_separator = arg.contains('/') || arg.contains(std::path::MAIN_SEPARATOR);
    if has_separator {
        return true;
    }
    match arg.rsplit_once('.') {
        Some((_, suffix)) => extensions.contains(&suffix),
        None => false,
    }
}
665
/// Directory names whose contents never affect the dependency graph.
///
/// Compared against individual path components, so each name excludes that
/// directory at any depth.
const EXCLUDED_DIRS: &[&str] = &[
    "node_modules",
    ".git",
    "__pycache__",
    ".chainsaw",
    "target",
];
668
669/// Check whether a filesystem event path is relevant to the dependency graph.
670///
671/// Returns true for source files with matching extensions and lockfiles.
672/// Returns false for files inside excluded directories or with unrelated extensions.
673fn is_relevant_path<S: AsRef<str>>(path: &Path, valid_extensions: &[S]) -> bool {
674    // Reject paths inside excluded directories.
675    for component in path.components() {
676        if let std::path::Component::Normal(s) = component
677            && let Some(s) = s.to_str()
678            && EXCLUDED_DIRS.contains(&s)
679        {
680            return false;
681        }
682    }
683
684    // Accept lockfiles by filename.
685    if let Some(name) = path.file_name().and_then(|n| n.to_str())
686        && LOCKFILES.contains(&name)
687    {
688        return true;
689    }
690
691    // Accept source files by extension.
692    path.extension()
693        .and_then(|e| e.to_str())
694        .is_some_and(|ext| valid_extensions.iter().any(|e| e.as_ref() == ext))
695}
696
697#[cfg(test)]
698mod tests {
699    use super::*;
700
701    fn test_project() -> (tempfile::TempDir, PathBuf) {
702        let tmp = tempfile::tempdir().unwrap();
703        let root = tmp.path().canonicalize().unwrap();
704        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
705        let entry = root.join("index.ts");
706        std::fs::write(&entry, r#"import { x } from "./a";"#).unwrap();
707        std::fs::write(root.join("a.ts"), "export const x = 1;").unwrap();
708        (tmp, entry)
709    }
710
711    #[test]
712    fn open_and_trace() {
713        let (_tmp, entry) = test_project();
714        let session = Session::open(&entry, true).unwrap();
715        assert_eq!(session.graph().module_count(), 2);
716        let opts = TraceOptions::default();
717        let result = session.trace(&opts);
718        assert!(result.static_weight > 0);
719    }
720
721    #[test]
722    fn chain_finds_dependency() {
723        let (_tmp, entry) = test_project();
724        let session = Session::open(&entry, true).unwrap();
725        let (resolved, chains) = session.chain("a.ts", false);
726        assert!(resolved.exists);
727        assert!(!chains.is_empty());
728    }
729
730    #[test]
731    fn cut_finds_no_intermediate_on_direct_import() {
732        let (_tmp, entry) = test_project();
733        let mut session = Session::open(&entry, true).unwrap();
734        // index.ts -> a.ts is a 1-hop chain, no intermediate to cut
735        let (resolved, chains, cuts) = session.cut("a.ts", 10, false);
736        assert!(resolved.exists);
737        assert!(!chains.is_empty());
738        assert!(cuts.is_empty());
739    }
740
741    #[test]
742    fn diff_two_entries() {
743        let tmp = tempfile::tempdir().unwrap();
744        let root = tmp.path().canonicalize().unwrap();
745        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
746        // a.ts imports b.ts (so both are in the graph when built from a.ts)
747        // b.ts imports extra.ts (so tracing from b.ts has more weight)
748        let a = root.join("a.ts");
749        std::fs::write(&a, r#"import { foo } from "./b";"#).unwrap();
750        let b = root.join("b.ts");
751        std::fs::write(&b, r#"import { bar } from "./extra";"#).unwrap();
752        std::fs::write(root.join("extra.ts"), "export const y = 2;").unwrap();
753
754        let mut session = Session::open(&a, true).unwrap();
755        let (diff, _) = session.diff_entry(&b, &TraceOptions::default()).unwrap();
756        // b.ts trace (b + extra) should have less weight than a.ts trace (a + b + extra)
757        assert!(diff.entry_a_weight >= diff.entry_b_weight);
758    }
759
760    #[test]
761    fn packages_returns_package_map() {
762        let (_tmp, entry) = test_project();
763        let session = Session::open(&entry, true).unwrap();
764        // Test project has no third-party packages
765        assert!(session.packages().is_empty());
766    }
767
768    #[test]
769    fn resolve_target_file_path() {
770        let (_tmp, entry) = test_project();
771        let session = Session::open(&entry, true).unwrap();
772        let resolved = session.resolve_target("a.ts");
773        assert!(resolved.exists);
774        assert!(matches!(resolved.target, ChainTarget::Module(_)));
775    }
776
777    #[test]
778    fn resolve_target_missing_package() {
779        let (_tmp, entry) = test_project();
780        let session = Session::open(&entry, true).unwrap();
781        let resolved = session.resolve_target("nonexistent-pkg");
782        assert!(!resolved.exists);
783        assert!(matches!(resolved.target, ChainTarget::Package(_)));
784    }
785
786    #[test]
787    fn scoped_npm_package_is_not_path() {
788        let exts = &["ts", "tsx", "js", "jsx"];
789        assert!(!looks_like_path("@slack/web-api", exts));
790        assert!(!looks_like_path("@aws-sdk/client-s3", exts));
791        assert!(!looks_like_path("@anthropic-ai/sdk", exts));
792    }
793
794    #[test]
795    fn relative_file_path_is_path() {
796        let exts = &["ts", "tsx", "js", "jsx"];
797        assert!(looks_like_path("src/index.ts", exts));
798        assert!(looks_like_path("lib/utils.js", exts));
799    }
800
801    #[test]
802    fn bare_package_name_is_not_path() {
803        let exts = &["ts", "tsx", "js", "jsx"];
804        assert!(!looks_like_path("zod", exts));
805        assert!(!looks_like_path("express", exts));
806        // highlight.js is ambiguous — .js extension triggers path heuristic.
807        // resolve_target tries as file path first, falls back to package lookup.
808        assert!(looks_like_path("highlight.js", exts));
809    }
810
811    #[test]
812    fn file_with_extension_is_path() {
813        let exts = &["ts", "tsx", "js", "jsx", "py"];
814        assert!(looks_like_path("utils.ts", exts));
815        assert!(looks_like_path("main.py", exts));
816        assert!(!looks_like_path("utils.txt", exts));
817    }
818
819    #[test]
820    fn resolve_target_falls_back_to_package_for_extension_name() {
821        let (_tmp, entry) = test_project();
822        let session = Session::open(&entry, true).unwrap();
823        // "six.py" looks like a file (.py extension) but no such file exists,
824        // so it falls back to package name lookup.
825        let resolved = session.resolve_target("six.py");
826        assert!(!resolved.exists);
827        assert!(matches!(resolved.target, ChainTarget::Package(ref name) if name == "six.py"));
828    }
829
830    #[test]
831    fn imports_lists_direct_dependencies() {
832        let (_tmp, entry) = test_project();
833        let session = Session::open(&entry, true).unwrap();
834        let imports = session.imports(session.entry()).unwrap();
835        assert_eq!(imports.len(), 1);
836        assert!(imports[0].0.ends_with("a.ts"));
837        assert!(matches!(imports[0].1, EdgeKind::Static));
838    }
839
840    #[test]
841    fn importers_lists_reverse_dependencies() {
842        let (_tmp, entry) = test_project();
843        let session = Session::open(&entry, true).unwrap();
844        let a_path = session.root().join("a.ts");
845        let importers = session.importers(&a_path).unwrap();
846        assert_eq!(importers.len(), 1);
847        assert!(importers[0].0.ends_with("index.ts"));
848    }
849
850    #[test]
851    fn set_entry_switches_entry_point() {
852        let tmp = tempfile::tempdir().unwrap();
853        let root = tmp.path().canonicalize().unwrap();
854        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
855        let a = root.join("a.ts");
856        std::fs::write(&a, r#"import { x } from "./b";"#).unwrap();
857        let b = root.join("b.ts");
858        std::fs::write(&b, "export const x = 1;").unwrap();
859
860        let mut session = Session::open(&a, true).unwrap();
861        assert!(session.entry().ends_with("a.ts"));
862        session.set_entry(&b).unwrap();
863        assert!(session.entry().ends_with("b.ts"));
864        // Tracing from b: only b itself (no imports).
865        let result = session.trace(&crate::query::TraceOptions::default());
866        assert_eq!(result.static_module_count, 1);
867    }
868
869    #[test]
870    fn refresh_detects_file_change() {
871        let tmp = tempfile::tempdir().unwrap();
872        let root = tmp.path().canonicalize().unwrap();
873        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
874        let entry = root.join("index.ts");
875        std::fs::write(&entry, r#"import { x } from "./a";"#).unwrap();
876        std::fs::write(root.join("a.ts"), "export const x = 1;").unwrap();
877
878        let mut session = Session::open(&entry, true).unwrap();
879        assert_eq!(session.graph().module_count(), 2);
880
881        // Modify entry to add a new import; sleep for mtime granularity.
882        std::thread::sleep(std::time::Duration::from_millis(50));
883        std::fs::write(
884            &entry,
885            r#"import { x } from "./a"; import { y } from "./b";"#,
886        )
887        .unwrap();
888        std::fs::write(root.join("b.ts"), "export const y = 2;").unwrap();
889
890        let changed = session.refresh().unwrap();
891        assert!(changed);
892        assert_eq!(session.graph().module_count(), 3);
893    }
894
895    #[test]
896    fn event_filter_accepts_ts_source() {
897        let exts = &["ts", "tsx", "js", "jsx"];
898        assert!(is_relevant_path(Path::new("/project/src/index.ts"), exts));
899        assert!(is_relevant_path(Path::new("/project/lib/utils.jsx"), exts));
900    }
901
902    #[test]
903    fn event_filter_accepts_py_source() {
904        let exts = &["py"];
905        assert!(is_relevant_path(Path::new("/project/app/main.py"), exts));
906    }
907
908    #[test]
909    fn event_filter_rejects_wrong_extension() {
910        let exts = &["ts", "tsx", "js", "jsx"];
911        assert!(!is_relevant_path(Path::new("/project/README.md"), exts));
912        assert!(!is_relevant_path(Path::new("/project/image.png"), exts));
913        assert!(!is_relevant_path(Path::new("/project/Makefile"), exts));
914    }
915
916    #[test]
917    fn event_filter_rejects_excluded_dirs() {
918        let exts = &["ts", "tsx", "js", "jsx"];
919        assert!(!is_relevant_path(
920            Path::new("/project/node_modules/zod/index.ts"),
921            exts
922        ));
923        assert!(!is_relevant_path(
924            Path::new("/project/.git/objects/abc"),
925            exts
926        ));
927        assert!(!is_relevant_path(
928            Path::new("/project/__pycache__/mod.py"),
929            exts
930        ));
931        assert!(!is_relevant_path(
932            Path::new("/project/.chainsaw/cache"),
933            exts
934        ));
935        assert!(!is_relevant_path(
936            Path::new("/project/target/debug/build.rs"),
937            exts
938        ));
939    }
940
941    #[test]
942    fn event_filter_accepts_lockfiles() {
943        let exts = &["ts", "tsx", "js", "jsx"];
944        assert!(is_relevant_path(
945            Path::new("/project/package-lock.json"),
946            exts
947        ));
948        assert!(is_relevant_path(Path::new("/project/pnpm-lock.yaml"), exts));
949        assert!(is_relevant_path(Path::new("/project/yarn.lock"), exts));
950        assert!(is_relevant_path(Path::new("/project/bun.lockb"), exts));
951        assert!(is_relevant_path(Path::new("/project/poetry.lock"), exts));
952        assert!(is_relevant_path(Path::new("/project/Pipfile.lock"), exts));
953        assert!(is_relevant_path(Path::new("/project/uv.lock"), exts));
954        assert!(is_relevant_path(
955            Path::new("/project/requirements.txt"),
956            exts
957        ));
958    }
959
960    #[test]
961    fn event_filter_rejects_no_extension_non_lockfile() {
962        let exts = &["ts", "tsx", "js", "jsx"];
963        assert!(!is_relevant_path(Path::new("/project/Dockerfile"), exts));
964    }
965
966    #[test]
967    fn entry_label_includes_project_dir() {
968        let (_tmp, entry) = test_project();
969        let session = Session::open(&entry, true).unwrap();
970        let label = session.entry_label();
971        // The temp dir has a name, so label should be "dirname/index.ts"
972        assert!(label.ends_with("index.ts"));
973        assert!(label.contains('/'));
974    }
975
976    #[test]
977    fn trace_report_has_display_ready_fields() {
978        let (_tmp, entry) = test_project();
979        let mut session = Session::open(&entry, true).unwrap();
980        let opts = TraceOptions::default();
981        let report = session.trace_report(&opts, report::DEFAULT_TOP_MODULES);
982        assert!(report.entry.contains("index.ts"));
983        assert!(report.static_weight_bytes > 0);
984        assert_eq!(report.static_module_count, 2);
985        // No ModuleIds -- paths are strings
986        assert!(
987            report
988                .modules_by_cost
989                .iter()
990                .all(|m| m.path.contains(".ts"))
991        );
992    }
993
994    #[test]
995    fn chain_report_resolves_to_strings() {
996        let (_tmp, entry) = test_project();
997        let session = Session::open(&entry, true).unwrap();
998        let report = session.chain_report("a.ts", false);
999        assert!(report.found_in_graph);
1000        assert_eq!(report.chain_count, 1);
1001        assert!(report.chains[0].iter().any(|s| s.contains("a.ts")));
1002    }
1003
1004    #[test]
1005    fn cut_report_direct_import() {
1006        let (_tmp, entry) = test_project();
1007        let mut session = Session::open(&entry, true).unwrap();
1008        let report = session.cut_report("a.ts", 10, false);
1009        assert!(report.found_in_graph);
1010        assert_eq!(report.chain_count, 1);
1011        assert!(report.direct_import);
1012        assert!(report.cut_points.is_empty());
1013    }
1014
1015    #[test]
1016    fn cut_report_nonexistent_target() {
1017        let (_tmp, entry) = test_project();
1018        let mut session = Session::open(&entry, true).unwrap();
1019        let report = session.cut_report("nonexistent-pkg", 10, false);
1020        assert!(!report.found_in_graph);
1021        assert_eq!(report.chain_count, 0);
1022        assert!(!report.direct_import);
1023    }
1024
1025    #[test]
1026    fn packages_report_empty_for_first_party() {
1027        let (_tmp, entry) = test_project();
1028        let session = Session::open(&entry, true).unwrap();
1029        let report = session.packages_report(report::DEFAULT_TOP);
1030        assert_eq!(report.package_count, 0);
1031        assert!(report.packages.is_empty());
1032    }
1033
1034    #[test]
1035    fn watch_then_refresh_returns_false_when_clean() {
1036        let (_tmp, entry) = test_project();
1037        let mut session = Session::open(&entry, true).unwrap();
1038        session.watch();
1039        // No files changed — refresh should be instant and return false.
1040        let changed = session.refresh().unwrap();
1041        assert!(!changed);
1042    }
1043
1044    #[test]
1045    fn refresh_without_watch_still_works() {
1046        // Backward compat: no watch() call, refresh runs the full path.
1047        let (_tmp, entry) = test_project();
1048        let mut session = Session::open(&entry, false).unwrap();
1049        // Wait for cache write to complete, then refresh hits the cache.
1050        std::thread::sleep(std::time::Duration::from_millis(50));
1051        let changed = session.refresh().unwrap();
1052        assert!(!changed); // cache hit, nothing changed
1053    }
1054
1055    #[test]
1056    fn watch_detects_file_modification() {
1057        let tmp = tempfile::tempdir().unwrap();
1058        let root = tmp.path().canonicalize().unwrap();
1059        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
1060        let entry = root.join("index.ts");
1061        std::fs::write(&entry, r#"import { x } from "./a";"#).unwrap();
1062        std::fs::write(root.join("a.ts"), "export const x = 1;").unwrap();
1063
1064        let mut session = Session::open(&entry, true).unwrap();
1065        session.watch();
1066
1067        // Modify a source file.
1068        std::thread::sleep(std::time::Duration::from_millis(100));
1069        std::fs::write(root.join("a.ts"), "export const x = 2;").unwrap();
1070
1071        // Give the watcher time to deliver the event.
1072        std::thread::sleep(std::time::Duration::from_millis(200));
1073
1074        assert!(session.is_dirty());
1075        let _changed = session.refresh().unwrap();
1076        // After refresh, the flag is cleared regardless of changed return value.
1077        assert!(!session.is_dirty());
1078    }
1079
1080    #[test]
1081    fn cached_trace_invalidated_on_set_entry() {
1082        let tmp = tempfile::tempdir().unwrap();
1083        let root = tmp.path().canonicalize().unwrap();
1084        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
1085        let a = root.join("a.ts");
1086        std::fs::write(&a, r#"import { x } from "./b";"#).unwrap();
1087        let b = root.join("b.ts");
1088        std::fs::write(&b, "export const x = 1;").unwrap();
1089
1090        let mut session = Session::open(&a, true).unwrap();
1091        let opts = crate::query::TraceOptions::default();
1092
1093        let r1 = session.trace_report(&opts, 10);
1094        assert_eq!(r1.static_module_count, 2);
1095
1096        session.set_entry(&b).unwrap();
1097
1098        let r2 = session.trace_report(&opts, 10);
1099        assert_eq!(r2.static_module_count, 1);
1100    }
1101
1102    #[test]
1103    fn cached_trace_invalidated_on_refresh() {
1104        let tmp = tempfile::tempdir().unwrap();
1105        let root = tmp.path().canonicalize().unwrap();
1106        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
1107        let entry = root.join("index.ts");
1108        std::fs::write(&entry, r#"import { x } from "./a";"#).unwrap();
1109        std::fs::write(root.join("a.ts"), "export const x = 1;").unwrap();
1110
1111        let mut session = Session::open(&entry, true).unwrap();
1112        let opts = crate::query::TraceOptions::default();
1113
1114        let r1 = session.trace_report(&opts, 10);
1115        assert_eq!(r1.static_module_count, 2);
1116
1117        std::thread::sleep(std::time::Duration::from_millis(50));
1118        std::fs::write(
1119            &entry,
1120            r#"import { x } from "./a"; import { y } from "./b";"#,
1121        )
1122        .unwrap();
1123        std::fs::write(root.join("b.ts"), "export const y = 2;").unwrap();
1124
1125        let changed = session.refresh().unwrap();
1126        assert!(changed);
1127
1128        let r2 = session.trace_report(&opts, 10);
1129        assert_eq!(r2.static_module_count, 3);
1130    }
1131
1132    #[test]
1133    fn cut_uses_cached_exclusive_weights() {
1134        let tmp = tempfile::tempdir().unwrap();
1135        let root = tmp.path().canonicalize().unwrap();
1136        std::fs::write(root.join("package.json"), r#"{"name":"test"}"#).unwrap();
1137        let entry = root.join("entry.ts");
1138        std::fs::write(
1139            &entry,
1140            r#"import { a } from "./a"; import { b } from "./b";"#,
1141        )
1142        .unwrap();
1143        std::fs::write(
1144            root.join("a.ts"),
1145            r#"import { c } from "./c"; export const a = 1;"#,
1146        )
1147        .unwrap();
1148        std::fs::write(
1149            root.join("b.ts"),
1150            r#"import { c } from "./c"; export const b = 1;"#,
1151        )
1152        .unwrap();
1153        std::fs::write(
1154            root.join("c.ts"),
1155            r#"import { z } from "./node_modules/zod/index.js"; export const c = 1;"#,
1156        )
1157        .unwrap();
1158        std::fs::create_dir_all(root.join("node_modules/zod")).unwrap();
1159        std::fs::write(
1160            root.join("node_modules/zod/index.js"),
1161            "export const z = 1;",
1162        )
1163        .unwrap();
1164        std::fs::write(
1165            root.join("node_modules/zod/package.json"),
1166            r#"{"name":"zod"}"#,
1167        )
1168        .unwrap();
1169
1170        let mut session = Session::open(&entry, true).unwrap();
1171
1172        let opts = crate::query::TraceOptions::default();
1173        session.trace_report(&opts, 10);
1174
1175        let (_, chains, cuts) = session.cut("zod", 10, false);
1176        assert!(!chains.is_empty());
1177        assert!(
1178            cuts.iter()
1179                .any(|c| session.graph().module(c.module_id).path.ends_with("c.ts"))
1180        );
1181    }
1182
1183    /// Verify query cache produces measurable speedup.
1184    /// Run: `cargo test --lib session::tests::verify_cache_speedup -- --ignored --nocapture`
1185    #[test]
1186    #[ignore = "requires local wrangler checkout"]
1187    fn verify_cache_speedup() {
1188        use std::time::Instant;
1189
1190        let wrangler =
1191            Path::new("/Users/hlal/dev/cloudflare/workers-sdk/packages/wrangler/src/index.ts");
1192        if !wrangler.exists() {
1193            eprintln!("SKIP: wrangler not found");
1194            return;
1195        }
1196        let mut session = Session::open(wrangler, true).unwrap();
1197        let opts = crate::query::TraceOptions::default();
1198
1199        let t1 = Instant::now();
1200        let r1 = session.trace_report(&opts, 10);
1201        let first = t1.elapsed();
1202
1203        let t2 = Instant::now();
1204        let r2 = session.trace_report(&opts, 10);
1205        let second = t2.elapsed();
1206
1207        assert_eq!(r1.static_weight_bytes, r2.static_weight_bytes);
1208        assert_eq!(r1.static_module_count, r2.static_module_count);
1209
1210        eprintln!(
1211            "  first trace_report:  {:.0}us",
1212            first.as_secs_f64() * 1_000_000.0
1213        );
1214        eprintln!(
1215            "  second trace_report: {:.0}us",
1216            second.as_secs_f64() * 1_000_000.0
1217        );
1218        eprintln!(
1219            "  speedup: {:.1}x",
1220            first.as_secs_f64() / second.as_secs_f64()
1221        );
1222
1223        assert!(
1224            second < first / 3,
1225            "expected cache hit to be at least 3x faster: first={first:?}, second={second:?}"
1226        );
1227    }
1228}