"""proteus-engine 0.2.0

Advanced zero-day static analysis engine built with Rust and Python.

Documentation: MalwareBazaar malware sample collector script.
"""
import requests
import json
import hashlib
import time
import io
import os
import argparse
from pathlib import Path
from typing import List, Dict, Optional, Any
from datetime import datetime

# Optional dependency: python-dotenv. When available, load_dotenv() is called
# once at import time; HAS_DOTENV records whether that happened.
try:
    from dotenv import load_dotenv
except ImportError:
    HAS_DOTENV = False
    print(
        "[!] Warning: python-dotenv not installed. Install with: pip install python-dotenv"
    )
else:
    load_dotenv()
    HAS_DOTENV = True

# Optional dependency: pyzipper. HAS_PYZIPPER tells the ZIP extraction code
# whether pyzipper.AESZipFile can be used instead of the stdlib zipfile module.
try:
    import pyzipper
except ImportError:
    HAS_PYZIPPER = False
    print("[!] Warning: pyzipper not installed. Install with: pip install pyzipper")
else:
    HAS_PYZIPPER = True


class MalwareBazaarCollector:
    """Collect malware samples from the MalwareBazaar API into a local dataset.

    Every downloaded sample is written to ``output_dir`` as
    ``<sha256>.malware``. A ``metadata.json`` file in the same directory maps
    each SHA-256 hash to the record describing that sample, and is used to
    skip samples that were already downloaded.
    """

    # The ``get_recent`` query documents only the selectors "time" and "100",
    # so the latest 100 additions are requested and truncated client-side.
    RECENT_SELECTOR = "100"
    # Archive passwords tried, in order, when a download arrives as a ZIP.
    ZIP_PASSWORDS = (b"infected", b"malware", b"")
    # Seconds to pause after each new download (presumably to rate-limit
    # requests against the API).
    DOWNLOAD_DELAY = 2

    def __init__(self, api_key: str, output_dir: str = "dataset/malicious"):
        """Create the output directory if needed and load existing metadata.

        Args:
            api_key: Value sent in the ``Auth-Key`` header of every request.
            output_dir: Directory receiving sample files and ``metadata.json``.
        """
        self.api_key = api_key
        self.api_url: str = "https://mb-api.abuse.ch/api/v1/"
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.metadata_file = self.output_dir / "metadata.json"
        self.downloaded: Dict[str, Any] = self.load_metadata()

    def _post(self, data: Dict[str, str], timeout: int) -> Any:
        """POST *data* to the API with the authentication headers attached."""
        headers = {"Auth-Key": self.api_key, "User-Agent": "Proteus/0.2.0"}
        return requests.post(
            self.api_url, data=data, headers=headers, timeout=timeout
        )

    def _quarantine_metadata(self) -> None:
        """Move an unreadable metadata file aside so it is not overwritten."""
        backup = self.metadata_file.with_name(self.metadata_file.name + ".corrupt")
        try:
            self.metadata_file.replace(backup)
            print(f"[!] Unreadable metadata moved to: {backup}")
        except OSError as e:
            print(f"[!] Could not move unreadable metadata aside: {e}")

    def load_metadata(self) -> Dict[str, Any]:
        """Return the metadata index stored on disk.

        Returns:
            The parsed ``metadata.json`` content, or an empty dict when the
            file is missing, unreadable, not valid JSON, or not a JSON object.
            An unusable file is renamed to ``metadata.json.corrupt`` first.
        """
        if not self.metadata_file.exists():
            return {}

        try:
            with open(self.metadata_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"[!] Could not read metadata ({e}); starting with an empty index")
            self._quarantine_metadata()
            return {}

        if not isinstance(data, dict):
            print("[!] Metadata is not a JSON object; starting with an empty index")
            self._quarantine_metadata()
            return {}

        return data

    def save_metadata(self) -> None:
        """Write the metadata index to ``metadata.json``.

        The JSON is written to a temporary sibling file and then renamed over
        the real one, so an interruption mid-write cannot truncate the index.
        """
        tmp_file = self.metadata_file.with_name(self.metadata_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(self.downloaded, f, indent=2)
        tmp_file.replace(self.metadata_file)

    def get_recent_samples(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to *limit* of the most recent samples listed by the API.

        Args:
            limit: Maximum number of entries to return. The API query used
                here yields at most the latest 100 additions.

        Returns:
            A list of sample dicts, or an empty list on any error.
        """
        if limit <= 0:
            return []

        data = {"query": "get_recent", "selector": self.RECENT_SELECTOR}

        try:
            response = self._post(data, timeout=30)

            if response.status_code != 200:
                print(f"[!] HTTP Error: {response.status_code}")
                print(f"[!] Response: {response.text[:200]}")
                return []

            result: Dict[str, Any] = response.json()
            if result.get("query_status") != "ok":
                print(f"[!] API Error: {result.get('query_status')}")
                return []

            # "data" may be present but null, hence ``or []``.
            return (result.get("data") or [])[:limit]
        except Exception as e:
            # Best-effort network boundary: report and fall through to [].
            print(f"[!] Request failed: {e}")

        return []

    def get_samples_by_tag(self, tag: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Return the samples the API lists for *tag*.

        Args:
            tag: Tag name sent with the ``get_taginfo`` query.
            limit: Value sent as the query's ``limit`` field.

        Returns:
            A list of sample dicts, or an empty list on any error.
        """
        data = {"query": "get_taginfo", "tag": tag, "limit": str(limit)}

        try:
            response = self._post(data, timeout=30)

            if response.status_code != 200:
                print(f"[!] HTTP Status Code: {response.status_code}")
                print(f"[!] Response text: {response.text[:500]}")
                return []

            result: Dict[str, Any] = response.json()
            if result.get("query_status") != "ok":
                print(f"[!] Tag '{tag}' returned: {result.get('query_status')}")
                return []

            return result.get("data") or []
        except Exception as e:
            # Best-effort network boundary: report with a traceback, return [].
            print(f"[!] Request EXCEPTION for tag {tag}: {e}")
            import traceback

            traceback.print_exc()

        return []

    def try_extract_zip(self, content: bytes, sha256_hash: str) -> Optional[bytes]:
        """Extract the first member of a (possibly password-protected) ZIP.

        Args:
            content: Raw bytes of the ZIP archive.
            sha256_hash: Hash of the expected sample (currently unused; kept
                for interface compatibility).

        Returns:
            The bytes of the archive's first member, or ``None`` when the
            archive is empty, unreadable, or none of the passwords works.
        """
        if HAS_PYZIPPER:
            zip_cls = pyzipper.AESZipFile
        else:
            import zipfile

            zip_cls = zipfile.ZipFile

        for password in self.ZIP_PASSWORDS:
            try:
                with zip_cls(io.BytesIO(content)) as zf:
                    names = zf.namelist()
                    if not names:
                        # An empty archive stays empty whatever the password.
                        return None
                    return zf.read(names[0], pwd=password)
            except Exception:
                # Wrong password or unreadable archive: try the next password.
                continue

        return None

    def download_sample(self, sha256_hash: str) -> Optional[str]:
        """Download one sample, verify its hash, and store it on disk.

        Args:
            sha256_hash: SHA-256 hash identifying the sample.

        Returns:
            The path of the stored file as a string (the previously recorded
            path when the sample is already in the metadata index), or
            ``None`` when the download, extraction, or hash check fails.
        """
        if sha256_hash in self.downloaded:
            print(f"    [SKIP] Already downloaded: {sha256_hash[:16]}...")
            path: Optional[str] = self.downloaded[sha256_hash].get("path")
            return path

        data = {"query": "get_file", "sha256_hash": sha256_hash}

        try:
            response = self._post(data, timeout=60)

            # Bodies of 100 bytes or fewer are treated as failed downloads.
            if response.status_code != 200 or len(response.content) <= 100:
                print(
                    f"    [!] Download failed: {sha256_hash[:16]} (Status: {response.status_code})"
                )
                return None

            content = response.content

            # "PK" is the ZIP magic number.
            if content[:2] == b"PK":
                print("    [*] ZIP detected, extracting...")
                extracted = self.try_extract_zip(content, sha256_hash)
                if not extracted:
                    print(f"    [!] Failed to extract ZIP for {sha256_hash[:16]}")
                    return None
                content = extracted
                print("    [+] Successfully extracted malware")

            # Only keep payloads whose digest matches the requested hash.
            actual_hash = hashlib.sha256(content).hexdigest()
            if actual_hash.lower() != sha256_hash.lower():
                print(f"    [!] Hash mismatch for {sha256_hash[:16]}")
                print(f"        Expected: {sha256_hash[:16]}...")
                print(f"        Got: {actual_hash[:16]}...")
                return None

            file_path = self.output_dir / f"{sha256_hash}.malware"
            with open(file_path, "wb") as f:
                f.write(content)

            return str(file_path)
        except Exception as e:
            print(f"    [!] Error downloading {sha256_hash[:16]}: {e}")

        return None

    def _collect(
        self,
        samples: List[Dict[str, Any]],
        total: int,
        tag_searched: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Download up to *total* of *samples* and record their metadata.

        Samples already present in the index keep their existing record
        untouched and do not trigger a save or a delay.

        Returns:
            Mapping of SHA-256 hash to metadata record for every sample that
            is available locally after this call.
        """
        collected: Dict[str, Any] = {}

        for idx, sample in enumerate(samples[:total], 1):
            sha256 = sample.get("sha256_hash")
            if not sha256:
                continue

            # These fields may be present but null, hence ``or``.
            signature = sample.get("signature") or "unknown"
            file_type = sample.get("file_type") or "unknown"

            print(f"[{idx}/{total}] {sha256[:16]}... ({signature})")

            already_known = sha256 in self.downloaded
            file_path = self.download_sample(sha256)
            if not file_path:
                continue

            if already_known:
                collected[sha256] = self.downloaded[sha256]
                continue

            record: Dict[str, Any] = {
                "path": file_path,
                "sha256": sha256,
                "type": file_type,
                "signature": signature,
                "tags": sample.get("tags") or [],
            }
            if tag_searched is not None:
                record["tag_searched"] = tag_searched
            record["first_seen"] = sample.get("first_seen") or ""
            record["downloaded_at"] = datetime.now().isoformat()

            collected[sha256] = record
            self.downloaded[sha256] = record
            # Persist after every sample so an interruption loses nothing.
            self.save_metadata()

            time.sleep(self.DOWNLOAD_DELAY)

        return collected

    def collect_by_tags(
        self, tags: List[str], samples_per_tag: int = 20
    ) -> Dict[str, Any]:
        """Download up to *samples_per_tag* samples for each tag in *tags*.

        Returns:
            Mapping of SHA-256 hash to metadata record for the samples
            obtained across all tags.
        """
        all_samples: Dict[str, Any] = {}

        for tag in tags:
            print(f"\n[*] Fetching samples with tag: {tag}")
            samples = self.get_samples_by_tag(tag, limit=samples_per_tag)

            if not samples:
                print(f"[!] No samples found for tag: {tag}")
                continue

            print(f"[+] Found {len(samples)} samples")
            all_samples.update(
                self._collect(samples, samples_per_tag, tag_searched=tag)
            )

        return all_samples

    def collect_recent(self, count: int = 50) -> Dict[str, Any]:
        """Download up to *count* of the most recently added samples.

        Returns:
            Mapping of SHA-256 hash to metadata record, empty when the API
            returned no samples.
        """
        print(f"\n[*] Fetching {count} recent samples from MalwareBazaar...")
        samples = self.get_recent_samples(limit=count)

        if not samples:
            print("[!] No recent samples returned")
            return {}

        print(f"[+] Found {len(samples)} samples")
        return self._collect(samples, count)

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize the metadata index.

        Returns:
            A dict with ``total`` (number of records) and the count mappings
            ``by_signature``, ``by_type`` and ``by_tag``. Missing or null
            signatures and types are counted under ``"unknown"``; missing or
            null tag lists contribute nothing to ``by_tag``.
        """
        by_signature: Dict[str, int] = {}
        by_type: Dict[str, int] = {}
        by_tag: Dict[str, int] = {}

        for sample in self.downloaded.values():
            sig = sample.get("signature") or "unknown"
            ftype = sample.get("type") or "unknown"

            by_signature[sig] = by_signature.get(sig, 0) + 1
            by_type[ftype] = by_type.get(ftype, 0) + 1

            for tag in sample.get("tags") or []:
                by_tag[tag] = by_tag.get(tag, 0) + 1

        return {
            "total": len(self.downloaded),
            "by_signature": by_signature,
            "by_type": by_type,
            "by_tag": by_tag,
        }


def _print_counts(
    title: str, counts: Dict[str, int], top: Optional[int] = None
) -> None:
    """Print *counts* under *title*, highest count first, at most *top* rows."""
    if not counts:
        return

    print(f"\n[*] {title}:")
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    # Slicing with ``None`` keeps every row.
    for name, count in ranked[:top]:
        print(f"    {name}: {count}")


def main() -> None:
    """Command-line entry point: collect samples per tag and print statistics.

    Reads the API key from the ``MALWAREBAZAAR_API_KEY`` environment variable,
    asks for confirmation, downloads ``--samples`` samples for each tag in a
    fixed tag list, then prints a summary of the metadata index.
    """
    parser = argparse.ArgumentParser(description="PROTEUS Malware Collector")
    parser.add_argument(
        "--samples",
        type=int,
        default=20,
        help="Number of samples to collect per tag (default: 20)",
    )
    args = parser.parse_args()
    if args.samples < 1:
        parser.error("--samples must be a positive integer")

    print("===========================================")
    print("   PROTEUS Malware Collector")
    print("   Source: MalwareBazaar")
    print("===========================================\n")

    api_key = os.getenv("MALWAREBAZAAR_API_KEY")

    if not api_key:
        print("[!] ERROR: MALWAREBAZAAR_API_KEY not found in environment!")
        print("[!] Please set it in .env file or environment variables")
        return

    collector = MalwareBazaarCollector(api_key=api_key)

    malware_tags = [
        "ransomware",
        "trojan",
        "rat",
        "stealer",
        "backdoor",
        "loader",
        "miner",
        "banker",
        "spyware",
        "worm",
    ]

    samples_per_tag = args.samples
    total_target = samples_per_tag * len(malware_tags)

    print("[*] Collection Strategy:")
    # Show only a short prefix so the secret does not end up in logs.
    print(f"    API Key: {api_key[:4]}... (masked)")
    print(f"    Tags: {', '.join(malware_tags)}")
    print(f"    Samples per tag: {samples_per_tag}")
    print(f"    Total target: ~{total_target} samples\n")

    try:
        input("Press ENTER to start collection (Ctrl+C to cancel)...")
    except (KeyboardInterrupt, EOFError):
        # The prompt advertises Ctrl+C, so exit quietly instead of a traceback.
        print("\n[!] Collection cancelled")
        return

    try:
        collector.collect_by_tags(malware_tags, samples_per_tag=samples_per_tag)
    except KeyboardInterrupt:
        # Metadata is saved after every sample, so the summary is still valid.
        print("\n[!] Collection interrupted; showing what was collected so far")

    print("\n==========================================")
    print("   Collection Statistics")
    print("==========================================")

    stats = collector.get_statistics()

    print(f"\n[+] Total samples: {stats['total']}")

    _print_counts("By Signature", stats["by_signature"], top=10)
    _print_counts("By File Type", stats["by_type"])
    _print_counts("By Tag", stats["by_tag"])

    print(f"\n[+] Samples saved to: {collector.output_dir}")
    print(f"[+] Metadata saved to: {collector.metadata_file}")


# Run the collector only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()