#!/usr/bin/env bash
set -euo pipefail

# Mixed-workload benchmark driver: ingests data and runs query profiles
# against the locally built `ir` binary, then writes per-profile metrics,
# an SLO-gate verdict, and JSON/Markdown summary reports.

# Resolve the repository root from this script's own location so the script
# works regardless of the caller's working directory.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$ROOT_DIR"

# Tunables — every one can be overridden from the environment:
#   ARTIFACTS_DIR    - output directory for reports, metrics and logs
#   REPORT_PREFIX    - filename prefix shared by all generated artifacts
#   SLO_GATE_SCRIPT  - script evaluated against each profile's .prom metrics
#   ALLOW_ZERO_TOTAL - when 1, pass --allow-zero-total through to the gate
#   MAX_*            - thresholds exported into the SLO gate's environment
ARTIFACTS_DIR="${ARTIFACTS_DIR:-$ROOT_DIR/artifacts}"
REPORT_PREFIX="${REPORT_PREFIX:-mixed_workload}"
SLO_GATE_SCRIPT="${SLO_GATE_SCRIPT:-$ROOT_DIR/scripts/slo_gate.sh}"
ALLOW_ZERO_TOTAL="${ALLOW_ZERO_TOTAL:-0}"
MAX_ERROR_RATE="${MAX_ERROR_RATE:-0.05}"
MAX_P95_MICROS="${MAX_P95_MICROS:-200000}"
MAX_P99_MICROS="${MAX_P99_MICROS:-300000}"
mkdir -p "$ARTIFACTS_DIR"

# Per-query logs live under one directory keyed by the report prefix.
LOG_DIR="$ARTIFACTS_DIR/${REPORT_PREFIX}_logs"
mkdir -p "$LOG_DIR"
REPORT_JSON="$ARTIFACTS_DIR/${REPORT_PREFIX}_report.json"
REPORT_MD="$ARTIFACTS_DIR/${REPORT_PREFIX}_report.md"

# Provenance stamped into both reports; commit falls back to "unknown"
# when the script runs outside a git checkout.
TIMESTAMP_UTC="$(date -u +"%Y-%m-%dT%H:%M:%SZ")"
COMMIT_SHA="$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")"

extract_json_number() {
  # Print the numeric value of key "$1" from the JSON-ish file "$2"
  # (first matching line), or nothing when the key is absent.
  #
  # sed -n with the s///p flag prints only lines where the substitution
  # matched, so a line that contains the key but no parseable value yields
  # empty output instead of echoing the raw line (the previous grep|sed
  # pipeline fell through like that when there was no space after ':').
  # Anchoring the capture on the key name itself also fixes extraction from
  # compact single-line JSON: the old unanchored '.*: ' regex greedily
  # grabbed the LAST value on the line regardless of which key it belonged to.
  local key="$1"
  local file="$2"
  sed -nE "s/.*\"$key\"[[:space:]]*:[[:space:]]*([^,}\"[:space:]]+).*/\1/p" "$file" | head -n1
}

percentile() {
  # Nearest-rank percentile: print the value at rank ceil(p * n / 100) of
  # the ascending-sorted inputs. Prints 0 when called with no values.
  # Usage: percentile <p> [value...]
  local pct="$1"
  shift
  local count=$#
  if (( count == 0 )); then
    echo 0
    return
  fi
  # Numeric sort via an external sort; collect the result back into an
  # array one line at a time through process substitution.
  local ordered=()
  while IFS= read -r item; do
    ordered+=("$item")
  done < <(printf '%s\n' "$@" | sort -n)
  # Integer ceiling of pct*count/100, clamped to at least rank 1.
  local rank=$(( (pct * count + 99) / 100 ))
  if (( rank < 1 )); then
    rank=1
  fi
  printf '%s\n' "${ordered[rank - 1]}"
}

now_ns() {
  perl -MTime::HiRes=time -e 'printf "%.0f\n", time() * 1000000000'
}

#######################################
# Run one workload profile end-to-end: a timed bulk ingest, then
# $query_runs rounds of three query shapes (fraud / recommendation /
# supply-chain), then per-profile Prometheus metrics and an SLO-gate check.
# Globals:   ROOT_DIR, LOG_DIR, ARTIFACTS_DIR, REPORT_PREFIX,
#            SLO_GATE_SCRIPT, ALLOW_ZERO_TOTAL, MAX_* thresholds (read)
# Arguments: $1 workdir   - dataset directory the `ir` binary operates in
#            $2 profile   - profile name used in file names and labels
#            $3 query_runs    - rounds of the three-query mix
#            $4 ingest_count  - batch size passed to ingest-batch-edge-loop
#            $5 ingest_rounds - ingest repetitions
# Outputs:   one pipe-delimited summary row on stdout (parsed by the
#            aggregation loop at the bottom of this script)
# Exits:     1 if the ingest output cannot be parsed
#######################################
run_profile() {
  local workdir="$1"
  local profile="$2"
  local query_runs="$3"
  local ingest_count="$4"
  local ingest_rounds="$5"

  local fraud_latencies=()
  local reco_latencies=()
  local supply_latencies=()
  local all_latencies=()
  local query_total=0
  local query_errors=0

  # --- Ingest phase: one batched edge-loop ingest, wall-clock timed. ---
  local ingest_start_ns
  ingest_start_ns="$(now_ns)"
  local ingest_out
  ingest_out="$(
    cd "$workdir"
    "$ROOT_DIR/target/debug/ir" ingest-batch-edge-loop 500000 "$ingest_count" "$ingest_rounds" 1 "p4-${profile}" 2>&1
  )"
  local ingest_end_ns
  ingest_end_ns="$(now_ns)"
  # Pull accepted=/rejected= counters out of the tool's textual output;
  # failure to find either is fatal (the row would otherwise be garbage).
  local accepted
  local rejected
  accepted="$(echo "$ingest_out" | sed -nE 's/.*accepted=([0-9]+).*/\1/p' | head -n1)"
  rejected="$(echo "$ingest_out" | sed -nE 's/.*rejected=([0-9]+).*/\1/p' | head -n1)"
  if [[ -z "$accepted" || -z "$rejected" ]]; then
    echo "failed to parse ingest output for profile $profile: $ingest_out" >&2
    exit 1
  fi
  local ingest_elapsed_ns=$((ingest_end_ns - ingest_start_ns))
  local ingest_elapsed_sec
  # awk does the ns->s and events/s float math; bash arithmetic is integer-only.
  ingest_elapsed_sec="$(awk -v ns="$ingest_elapsed_ns" 'BEGIN { printf "%.6f", ns / 1000000000.0 }')"
  local ingest_eps
  ingest_eps="$(awk -v events="$accepted" -v sec="$ingest_elapsed_sec" 'BEGIN { if (sec <= 0) { print "0.00"; } else { printf "%.2f", events / sec; } }')"

  # Three representative query shapes; single quotes keep $vec literal for
  # the query engine to resolve.
  local q_fraud='MATCH (n) WHERE vector.cosine(n.embedding, $vec) > 0.80 RETURN n LIMIT 20'
  local q_reco='MATCH (n) RETURN n LIMIT 25'
  local q_supply='MATCH (n) WHERE vector.cosine(n.embedding, $vec) > 0.20 RETURN n LIMIT 50'

  # --- Query phase: each round runs all three shapes. A failed invocation
  # or an unparseable log records a placeholder latency of 0 and counts as
  # an error. NOTE(review): those 0 placeholders are fed into the
  # percentile inputs, which biases p95/p99 DOWNWARD when errors occur —
  # confirm that is intended rather than excluding failed runs.
  for run in $(seq 1 "$query_runs"); do
    local fraud_log="$LOG_DIR/${profile}_fraud_q${run}.log"
    local reco_log="$LOG_DIR/${profile}_reco_q${run}.log"
    local supply_log="$LOG_DIR/${profile}_supply_q${run}.log"
    if (
      cd "$workdir"
      "$ROOT_DIR/target/debug/ir" query "$q_fraud"
    ) >"$fraud_log" 2>&1; then
      local fraud_latency
      fraud_latency="$(extract_json_number "latency_micros" "$fraud_log" || true)"
      if [[ -n "$fraud_latency" ]]; then
        fraud_latencies+=("$fraud_latency")
        all_latencies+=("$fraud_latency")
      else
        fraud_latencies+=(0)
        query_errors=$((query_errors + 1))
      fi
    else
      fraud_latencies+=(0)
      query_errors=$((query_errors + 1))
    fi
    query_total=$((query_total + 1))

    # Recommendation query — same bookkeeping as the fraud query above.
    if (
      cd "$workdir"
      "$ROOT_DIR/target/debug/ir" query "$q_reco"
    ) >"$reco_log" 2>&1; then
      local reco_latency
      reco_latency="$(extract_json_number "latency_micros" "$reco_log" || true)"
      if [[ -n "$reco_latency" ]]; then
        reco_latencies+=("$reco_latency")
        all_latencies+=("$reco_latency")
      else
        reco_latencies+=(0)
        query_errors=$((query_errors + 1))
      fi
    else
      reco_latencies+=(0)
      query_errors=$((query_errors + 1))
    fi
    query_total=$((query_total + 1))

    # Supply-chain query — same bookkeeping as the fraud query above.
    if (
      cd "$workdir"
      "$ROOT_DIR/target/debug/ir" query "$q_supply"
    ) >"$supply_log" 2>&1; then
      local supply_latency
      supply_latency="$(extract_json_number "latency_micros" "$supply_log" || true)"
      if [[ -n "$supply_latency" ]]; then
        supply_latencies+=("$supply_latency")
        all_latencies+=("$supply_latency")
      else
        supply_latencies+=(0)
        query_errors=$((query_errors + 1))
      fi
    else
      supply_latencies+=(0)
      query_errors=$((query_errors + 1))
    fi
    query_total=$((query_total + 1))
  done

  # --- Percentiles per shape, plus combined across all shapes. ---
  local fraud_p95 reco_p95 supply_p95
  local fraud_p99 reco_p99 supply_p99
  local query_p95 query_p99
  fraud_p95="$(percentile 95 "${fraud_latencies[@]}")"
  reco_p95="$(percentile 95 "${reco_latencies[@]}")"
  supply_p95="$(percentile 95 "${supply_latencies[@]}")"
  fraud_p99="$(percentile 99 "${fraud_latencies[@]}")"
  reco_p99="$(percentile 99 "${reco_latencies[@]}")"
  supply_p99="$(percentile 99 "${supply_latencies[@]}")"
  query_p95="$(percentile 95 "${all_latencies[@]}")"
  query_p99="$(percentile 99 "${all_latencies[@]}")"

  # --- Emit Prometheus-style metrics consumed by the SLO gate. ---
  local metrics_file="$ARTIFACTS_DIR/${REPORT_PREFIX}_${profile}_metrics.prom"
  cat >"$metrics_file" <<EOF
# TYPE iridium_query_total counter
iridium_query_total $query_total
# TYPE iridium_query_errors counter
iridium_query_errors $query_errors
# TYPE iridium_query_p95_latency_micros gauge
iridium_query_p95_latency_micros $query_p95
# TYPE iridium_query_p99_latency_micros gauge
iridium_query_p99_latency_micros $query_p99
EOF

  # --- SLO gate: run the gate script with thresholds passed via env;
  # capture its exit status without tripping `set -e`. ---
  local gate_json="$ARTIFACTS_DIR/${REPORT_PREFIX}_${profile}_slo_gate.json"
  local gate_md="$ARTIFACTS_DIR/${REPORT_PREFIX}_${profile}_slo_gate.md"
  local gate_status=0
  if [[ "$ALLOW_ZERO_TOTAL" -eq 1 ]]; then
    MAX_ERROR_RATE="$MAX_ERROR_RATE" \
    MAX_P95_MICROS="$MAX_P95_MICROS" \
    MAX_P99_MICROS="$MAX_P99_MICROS" \
    OUT_JSON="$gate_json" OUT_MD="$gate_md" \
      bash "$SLO_GATE_SCRIPT" "$metrics_file" --allow-zero-total >/dev/null || gate_status=$?
  else
    MAX_ERROR_RATE="$MAX_ERROR_RATE" \
    MAX_P95_MICROS="$MAX_P95_MICROS" \
    MAX_P99_MICROS="$MAX_P99_MICROS" \
    OUT_JSON="$gate_json" OUT_MD="$gate_md" \
      bash "$SLO_GATE_SCRIPT" "$metrics_file" >/dev/null || gate_status=$?
  fi

  # Summary row. Field order is a contract with the `IFS='|' read -r ...`
  # in the aggregation loop below — keep the two in sync.
  echo "$profile|$query_runs|$ingest_count|$ingest_rounds|$accepted|$rejected|$query_total|$query_errors|$ingest_eps|$fraud_p95|$fraud_p99|$reco_p95|$reco_p99|$supply_p95|$supply_p99|$query_p95|$query_p99|$gate_status|$(basename "$metrics_file")|$(basename "$gate_json")"
}

echo "Building ir binary..."
cargo build --bin ir >/dev/null

# Scratch dataset directory, cleaned up on any exit path.
WORKDIR="$(mktemp -d)"
trap 'rm -rf "$WORKDIR"' EXIT
echo "Preparing mixed-workload dataset in $WORKDIR..."

# Seed 800 nodes, each pointing at its two successors. The whole loop runs
# inside ONE subshell so we pay for a single fork+cd instead of spawning a
# fresh subshell per node, as the per-iteration `( cd ...; ir ... )` did.
(
  cd "$WORKDIR"
  for ((id = 1; id <= 800; id++)); do
    n1=$((id + 1))
    n2=$((id + 2))
    "$ROOT_DIR/target/debug/ir" ingest-node "$id" 1 "$n1,$n2" >/dev/null
  done
)

# One summary row per profile; parsed by the aggregation loop below.
PROFILE_ROWS=()
PROFILE_ROWS+=("$(run_profile "$WORKDIR" "light" 10 400 2)")
PROFILE_ROWS+=("$(run_profile "$WORKDIR" "balanced" 15 800 3)")
PROFILE_ROWS+=("$(run_profile "$WORKDIR" "burst" 20 1200 4)")

# Fold the per-profile rows into a JSON fragment and Markdown bullet list;
# a single failed SLO gate flips overall_pass to false.
JSON_WORKLOADS=""
MD_ROWS=""
overall_pass=true
for row in "${PROFILE_ROWS[@]}"; do
  # Field order here must match the pipe-delimited echo in run_profile.
  IFS='|' read -r profile query_runs ingest_count ingest_rounds accepted rejected query_total query_errors ingest_eps \
    fraud_p95 fraud_p99 reco_p95 reco_p99 supply_p95 supply_p99 query_p95 query_p99 gate_status metrics_file gate_json <<<"$row"
  if [[ "$gate_status" -ne 0 ]]; then
    overall_pass=false
  fi
  # Comma-separate profile objects; the leading newline + indent keeps the
  # generated JSON readable inside the report heredoc.
  if [[ -n "$JSON_WORKLOADS" ]]; then
    JSON_WORKLOADS+=","
  fi
  JSON_WORKLOADS+=$'\n    '
  JSON_WORKLOADS+="\"$profile\": {\"query_runs\": $query_runs, \"query_total\": $query_total, \"query_errors\": $query_errors, \"ingest_count\": $ingest_count, \"ingest_rounds\": $ingest_rounds, \"accepted\": $accepted, \"rejected\": $rejected, \"ingest_events_per_sec\": $ingest_eps, \"fraud_p95\": $fraud_p95, \"fraud_p99\": $fraud_p99, \"recommendation_p95\": $reco_p95, \"recommendation_p99\": $reco_p99, \"supply_chain_p95\": $supply_p95, \"supply_chain_p99\": $supply_p99, \"query_p95\": $query_p95, \"query_p99\": $query_p99, \"slo_gate_pass\": $( [[ "$gate_status" -eq 0 ]] && echo "true" || echo "false" ), \"metrics_file\": \"$metrics_file\", \"slo_gate_json\": \"$gate_json\"}"
  MD_ROWS+="- ${profile}: ingest_eps=${ingest_eps}, query(total=${query_total}, errors=${query_errors}, p95=${query_p95}, p99=${query_p99}), fraud(p95=${fraud_p95}, p99=${fraud_p99}), reco(p95=${reco_p95}, p99=${reco_p99}), supply(p95=${supply_p95}, p99=${supply_p99}), slo_gate=$( [[ "$gate_status" -eq 0 ]] && echo "PASS" || echo "FAIL" )"$'\n'
done

# Machine-readable report. The unquoted EOF delimiter makes the heredoc
# expand variables, interpolating the pre-built JSON fragment verbatim.
cat >"$REPORT_JSON" <<EOF
{
  "timestamp_utc": "$TIMESTAMP_UTC",
  "commit_sha": "$COMMIT_SHA",
  "profiles": {${JSON_WORKLOADS}
  },
  "thresholds": {
    "max_error_rate": $MAX_ERROR_RATE,
    "max_p95_latency_micros": $MAX_P95_MICROS,
    "max_p99_latency_micros": $MAX_P99_MICROS
  },
  "overall_pass": $overall_pass
}
EOF

# Human-readable companion report with the same data.
cat >"$REPORT_MD" <<EOF
# Mixed-Workload Matrix Report

- Timestamp (UTC): $TIMESTAMP_UTC
- Commit: $COMMIT_SHA
- Report prefix: $REPORT_PREFIX

## Profile Results
$MD_ROWS
## Thresholds
- max_error_rate: $MAX_ERROR_RATE
- max_p95_latency_micros: $MAX_P95_MICROS
- max_p99_latency_micros: $MAX_P99_MICROS

## Overall
- overall_pass: $overall_pass
EOF

echo "Wrote:"
echo "  $REPORT_JSON"
echo "  $REPORT_MD"

# Reports are written even on failure; only the exit code signals the
# threshold breach to CI.
if [[ "$overall_pass" != "true" ]]; then
  echo "mixed workload matrix failed thresholds" >&2
  exit 1
fi
